diff --git a/cmd/chatlog/cmd_decrypt.go b/cmd/chatlog/cmd_decrypt.go index 5692a6b..24d31ea 100644 --- a/cmd/chatlog/cmd_decrypt.go +++ b/cmd/chatlog/cmd_decrypt.go @@ -6,7 +6,7 @@ import ( "github.com/sjzar/chatlog/internal/chatlog" - log "github.com/sirupsen/logrus" + "github.com/rs/zerolog/log" "github.com/spf13/cobra" ) @@ -33,11 +33,11 @@ var decryptCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { m, err := chatlog.New("") if err != nil { - log.Error(err) + log.Err(err).Msg("failed to create chatlog instance") return } if err := m.CommandDecrypt(dataDir, workDir, key, decryptPlatform, decryptVer); err != nil { - log.Error(err) + log.Err(err).Msg("failed to decrypt") return } fmt.Println("decrypt success") diff --git a/cmd/chatlog/cmd_key.go b/cmd/chatlog/cmd_key.go index 25336fc..9bc1c51 100644 --- a/cmd/chatlog/cmd_key.go +++ b/cmd/chatlog/cmd_key.go @@ -5,7 +5,7 @@ import ( "github.com/sjzar/chatlog/internal/chatlog" - log "github.com/sirupsen/logrus" + "github.com/rs/zerolog/log" "github.com/spf13/cobra" ) @@ -21,12 +21,12 @@ var keyCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { m, err := chatlog.New("") if err != nil { - log.Error(err) + log.Err(err).Msg("failed to create chatlog instance") return } ret, err := m.CommandKey(pid) if err != nil { - log.Error(err) + log.Err(err).Msg("failed to get key") return } fmt.Println(ret) diff --git a/cmd/chatlog/log.go b/cmd/chatlog/log.go index 54b15ab..f2eea76 100644 --- a/cmd/chatlog/log.go +++ b/cmd/chatlog/log.go @@ -1,33 +1,26 @@ package chatlog import ( - "fmt" "io" "os" - "path" "path/filepath" - "runtime" + "time" "github.com/sjzar/chatlog/pkg/util" - log "github.com/sirupsen/logrus" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "github.com/sirupsen/logrus" "github.com/spf13/cobra" ) var Debug bool func initLog(cmd *cobra.Command, args []string) { - log.SetFormatter(&log.TextFormatter{ - FullTimestamp: true, - CallerPrettyfier: func(f *runtime.Frame) (string, string) { - _, filename := path.Split(f.File) - return "", fmt.Sprintf("%s:%d", filename, f.Line) - }, - }) + zerolog.SetGlobalLevel(zerolog.InfoLevel) if Debug { - log.SetLevel(log.DebugLevel) - log.SetReportCaller(true) + zerolog.SetGlobalLevel(zerolog.DebugLevel) } } @@ -43,8 +36,8 @@ func initTuiLog(cmd *cobra.Command, args []string) { panic(err) } logOutput = logFD - log.SetReportCaller(true) } - log.SetOutput(logOutput) + log.Logger = log.Output(zerolog.ConsoleWriter{Out: logOutput, NoColor: true, TimeFormat: time.RFC3339}) + logrus.SetOutput(logOutput) } diff --git a/cmd/chatlog/root.go b/cmd/chatlog/root.go index c360476..477576a 100644 --- a/cmd/chatlog/root.go +++ b/cmd/chatlog/root.go @@ -3,7 +3,7 @@ package chatlog import ( "github.com/sjzar/chatlog/internal/chatlog" - log "github.com/sirupsen/logrus" + "github.com/rs/zerolog/log" "github.com/spf13/cobra" ) @@ -17,7 +17,7 @@ func init() { func Execute() { if err := rootCmd.Execute(); err != nil { - log.Error(err) + log.Err(err).Msg("command execution failed") } } @@ -38,11 +38,11 @@ func Root(cmd *cobra.Command, args []string) { m, err := chatlog.New("") if err != nil { - log.Error(err) + log.Err(err).Msg("failed to create chatlog instance") return } if err := m.Run(); err != nil { - log.Error(err) + log.Err(err).Msg("failed to run chatlog instance") } } diff --git a/go.mod b/go.mod index c18e687..b894c0f 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/mattn/go-sqlite3 v1.14.24 github.com/pierrec/lz4/v4 v4.1.22 github.com/rivo/tview v0.0.0-20250325173046-7b72abf45814 + github.com/rs/zerolog v1.34.0 github.com/shirou/gopsutil/v4 v4.25.2 github.com/sirupsen/logrus v1.9.3 github.com/spf13/cobra v1.9.1 @@ -41,6 +42,7 @@ require ( github.com/leodido/go-urn v1.4.0 // indirect github.com/lucasb-eyer/go-colorful v1.2.0 // indirect github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35 // indirect + github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-runewidth v0.0.16 // indirect github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect diff --git a/go.sum b/go.sum index b0e5eda..98ebcbb 100644 --- a/go.sum +++ b/go.sum @@ -6,6 +6,7 @@ github.com/bytedance/sonic/loader v0.2.4/go.mod h1:N8A3vUdtUebEY2/VQC0MyhYeKUFos github.com/cloudwego/base64x v0.1.5 h1:XPciSp1xaq2VCSt6lF0phncD4koWyULpl5bUxbfCyP4= github.com/cloudwego/base64x v0.1.5/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= @@ -41,6 +42,7 @@ github.com/go-viper/mapstructure/v2 v2.2.1 h1:ZAaOCxANMuZx5RCeg0mBdEZk7DZasvvZIx github.com/go-viper/mapstructure/v2 v2.2.1/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/goccy/go-json v0.10.5 h1:Fq85nIqj+gXn/S5ahsiTlK3TmC85qgirsdTP/+DeaC4= github.com/goccy/go-json v0.10.5/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= @@ -68,6 +70,10 @@ github.com/lucasb-eyer/go-colorful v1.2.0 h1:1nnpGOrhyZZuNyfu1QjKiUICQ74+3FNCN69 github.com/lucasb-eyer/go-colorful v1.2.0/go.mod h1:R4dSotOR9KMtayYi1e77YzuveK+i7ruzyGqttikkLy0= github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35 h1:PpXWgLPs+Fqr325bN2FD2ISlRRztXibcX6e8f5FR5Dc= github.com/lufia/plan9stats v0.0.0-20250317134145-8bc96cf8fc35/go.mod h1:autxFIvghDt3jPTLoqZ9OZ7s9qTGNAWmYCjVFWPX/zg= +github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= +github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth v0.0.16 h1:E5ScNMtiwvlvB5paMFdw9p4kSQzbXFikJ5SQO6TULQc= @@ -83,6 +89,7 @@ github.com/pelletier/go-toml/v2 v2.2.3 h1:YmeHyLY8mFWbdkNWwpr+qIL2bEqT0o95WSdkNH github.com/pelletier/go-toml/v2 v2.2.3/go.mod h1:MfCQTFTvCcUyyvvwm1+G6H/jORL20Xlb6rzQu9GuUkc= github.com/pierrec/lz4/v4 v4.1.22 h1:cKFw6uJDK+/gfw5BcDL0JL5aBsAFdsIT18eRtLj7VIU= github.com/pierrec/lz4/v4 v4.1.22/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 h1:o4JXh1EVt9k/+g42oCprj/FisM4qX9L3sZB3upGN2ZU= @@ -95,6 +102,9 @@ github.com/rivo/uniseg v0.4.7 h1:WUdvkW8uEhrYfLC4ZzdpI2ztxP1I582+49Oc5Mq64VQ= github.com/rivo/uniseg v0.4.7/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.9.0 h1:GbgQGNtTrEmddYDSAH9QLRyfAHY12md+8YFTqyMTC9k= github.com/sagikazarmark/locafero v0.9.0/go.mod h1:UBUyz37V+EdMS3hDF3QWIiVr/2dPrx49OMO0Bn0hJqk= @@ -180,6 +190,7 @@ golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/chatlog/http/route.go b/internal/chatlog/http/route.go index 4c9a781..8e06703 100644 --- a/internal/chatlog/http/route.go +++ b/internal/chatlog/http/route.go @@ -88,7 +88,7 @@ func (s *Service) GetChatlog(c *gin.Context) { var err error start, end, ok := util.TimeRangeOf(q.Time) if !ok { - errors.Err(c, errors.ErrInvalidArg("time")) + errors.Err(c, errors.InvalidArg("time")) } if q.Limit < 0 { q.Limit = 0 @@ -276,7 +276,7 @@ func (s *Service) GetFile(c *gin.Context) { func (s *Service) GetMedia(c *gin.Context, _type string) { key := c.Param("key") if key == "" { - errors.Err(c, errors.ErrInvalidArg(key)) + errors.Err(c, errors.InvalidArg(key)) return } diff --git a/internal/chatlog/http/service.go b/internal/chatlog/http/service.go index f027ec4..08d04c1 100644 --- a/internal/chatlog/http/service.go +++ b/internal/chatlog/http/service.go @@ -11,7 +11,7 @@ import ( "github.com/sjzar/chatlog/internal/errors" "github.com/gin-gonic/gin" - log "github.com/sirupsen/logrus" + "github.com/rs/zerolog/log" ) const ( @@ -33,14 +33,14 @@ func NewService(ctx *ctx.Context, db *database.Service, mcp *mcp.Service) *Servi // Handle error from SetTrustedProxies if err := router.SetTrustedProxies(nil); err != nil { - log.Error("Failed to set trusted proxies:", err) + log.Err(err).Msg("Failed to set trusted proxies") } // Middleware router.Use( errors.RecoveryMiddleware(), errors.ErrorHandlerMiddleware(), - gin.LoggerWithWriter(log.StandardLogger().Out), + gin.LoggerWithWriter(log.Logger), ) s := &Service{ @@ -68,11 +68,11 @@ func (s *Service) Start() error { go func() { // Handle error from Run if err := s.server.ListenAndServe(); err != nil { - log.Error("Server Stopped: ", err) + log.Err(err).Msg("Failed to start HTTP server") } }() - log.Info("Server started on ", s.ctx.HTTPAddr) + log.Info().Msg("Starting HTTP server on " + s.ctx.HTTPAddr) return nil } @@ -88,10 +88,10 @@ func (s *Service) Stop() error { defer cancel() if err := s.server.Shutdown(ctx); err != nil { - return errors.HTTP("HTTP server shutdown error", err) + return errors.HTTPShutDown(err) } - log.Info("HTTP server stopped") + log.Info().Msg("HTTP server stopped") return nil } diff --git a/internal/chatlog/wechat/service.go b/internal/chatlog/wechat/service.go index 584970b..cdba4b8 100644 --- a/internal/chatlog/wechat/service.go +++ b/internal/chatlog/wechat/service.go @@ -13,7 +13,7 @@ import ( "github.com/sjzar/chatlog/internal/wechat/decrypt" "github.com/sjzar/chatlog/pkg/util" - log "github.com/sirupsen/logrus" + "github.com/rs/zerolog/log" ) type Service struct { @@ -64,7 +64,7 @@ func (s *Service) FindDBFiles(rootDir string, recursive bool) ([]string, error) walkFunc := func(path string, info os.FileInfo, err error) error { if err != nil { // If a file or directory can't be accessed, log the error but continue - fmt.Printf("Warning: Cannot access %s: %v\n", path, err) + log.Err(err).Msgf("Warning: Cannot access %s", path) return nil } @@ -121,7 +121,7 @@ func (s *Service) DecryptDBFiles(dataDir string, workDir string, key string, pla defer outputFile.Close() if err := decryptor.Decrypt(ctx, dbfile, key, outputFile); err != nil { - log.Debugf("failed to decrypt %s: %v", dbfile, err) + log.Err(err).Msgf("failed to decrypt %s", dbfile) if err == errors.ErrAlreadyDecrypted { if data, err := os.ReadFile(dbfile); err == nil { outputFile.Write(data) diff --git a/internal/errors/domain_errors.go b/internal/errors/domain_errors.go deleted file mode 100644 index 457b303..0000000 --- a/internal/errors/domain_errors.go +++ /dev/null @@ -1,151 +0,0 @@ -package errors - -import ( - "fmt" - "net/http" -) - -// 微信相关错误 - -// WeChatProcessNotFound 创建微信进程未找到错误 -func WeChatProcessNotFound() *AppError { - return New(ErrTypeWeChat, "wechat process not found", nil, http.StatusNotFound).WithStack() -} - -// WeChatKeyExtractFailed 创建微信密钥提取失败错误 -func WeChatKeyExtractFailed(cause error) *AppError { - return New(ErrTypeWeChat, "failed to extract wechat key", cause, http.StatusInternalServerError).WithStack() -} - -// WeChatDecryptFailed 创建微信解密失败错误 -func WeChatDecryptFailed(cause error) *AppError { - return New(ErrTypeWeChat, "failed to decrypt wechat database", cause, http.StatusInternalServerError).WithStack() -} - -// WeChatAccountNotSelected 创建未选择微信账号错误 -func WeChatAccountNotSelected() *AppError { - return New(ErrTypeWeChat, "no wechat account selected", nil, http.StatusBadRequest).WithStack() -} - -// 数据库相关错误 - -// DBConnectionFailed 创建数据库连接失败错误 -func DBConnectionFailed(cause error) *AppError { - return New(ErrTypeDatabase, "database connection failed", cause, http.StatusInternalServerError).WithStack() -} - -// DBQueryFailed 创建数据库查询失败错误 -func DBQueryFailed(operation string, cause error) *AppError { - return New(ErrTypeDatabase, fmt.Sprintf("database query failed: %s", operation), cause, http.StatusInternalServerError).WithStack() -} - -// DBRecordNotFound 创建数据库记录未找到错误 -func DBRecordNotFound(resource string) *AppError { - return New(ErrTypeNotFound, fmt.Sprintf("record not found: %s", resource), nil, http.StatusNotFound).WithStack() -} - -// 配置相关错误 - -// ConfigInvalid 创建配置无效错误 -func ConfigInvalid(field string, cause error) *AppError { - return New(ErrTypeConfig, fmt.Sprintf("invalid configuration: %s", field), cause, http.StatusInternalServerError).WithStack() -} - -// ConfigMissing 创建配置缺失错误 -func ConfigMissing(field string) *AppError { - return New(ErrTypeConfig, fmt.Sprintf("missing configuration: %s", field), nil, http.StatusBadRequest).WithStack() -} - -// 平台相关错误 - -// PlatformUnsupported 创建不支持的平台错误 -func PlatformUnsupported(platform string, version int) *AppError { - return New(ErrTypeInvalidArg, fmt.Sprintf("unsupported platform: %s v%d", platform, version), nil, http.StatusBadRequest).WithStack() -} - -// 文件系统错误 - -// FileNotFound 创建文件未找到错误 -func FileNotFound(path string) *AppError { - return New(ErrTypeNotFound, fmt.Sprintf("file not found: %s", path), nil, http.StatusNotFound).WithStack() -} - -// FileReadFailed 创建文件读取失败错误 -func FileReadFailed(path string, cause error) *AppError { - return New(ErrTypeInternal, fmt.Sprintf("failed to read file: %s", path), cause, http.StatusInternalServerError).WithStack() -} - -// FileWriteFailed 创建文件写入失败错误 -func FileWriteFailed(path string, cause error) *AppError { - return New(ErrTypeInternal, fmt.Sprintf("failed to write file: %s", path), cause, http.StatusInternalServerError).WithStack() -} - -// 参数验证错误 - -// RequiredParam 创建必需参数缺失错误 -func RequiredParam(param string) *AppError { - return New(ErrTypeInvalidArg, fmt.Sprintf("required parameter missing: %s", param), nil, http.StatusBadRequest).WithStack() -} - -// InvalidParam 创建参数无效错误 -func InvalidParam(param string, reason string) *AppError { - message := fmt.Sprintf("invalid parameter: %s", param) - if reason != "" { - message = fmt.Sprintf("%s (%s)", message, reason) - } - return New(ErrTypeInvalidArg, message, nil, http.StatusBadRequest).WithStack() -} - -// 解密相关错误 - -// DecryptInvalidKey 创建无效密钥格式错误 -func DecryptInvalidKey(cause error) *AppError { - return New(ErrTypeWeChat, "invalid key format", cause, http.StatusBadRequest). - WithStack() -} - -// DecryptCreateCipherFailed 创建无法创建加密器错误 -func DecryptCreateCipherFailed(cause error) *AppError { - return New(ErrTypeWeChat, "failed to create cipher", cause, http.StatusInternalServerError). - WithStack() -} - -// DecryptDecodeKeyFailed 创建无法解码十六进制密钥错误 -func DecryptDecodeKeyFailed(cause error) *AppError { - return New(ErrTypeWeChat, "failed to decode hex key", cause, http.StatusBadRequest). - WithStack() -} - -// DecryptWriteOutputFailed 创建无法写入输出错误 -func DecryptWriteOutputFailed(cause error) *AppError { - return New(ErrTypeWeChat, "failed to write decryption output", cause, http.StatusInternalServerError). - WithStack() -} - -// DecryptOperationCanceled 创建解密操作被取消错误 -func DecryptOperationCanceled() *AppError { - return New(ErrTypeWeChat, "decryption operation was canceled", nil, http.StatusBadRequest). - WithStack() -} - -// DecryptOpenFileFailed 创建无法打开数据库文件错误 -func DecryptOpenFileFailed(path string, cause error) *AppError { - return New(ErrTypeWeChat, fmt.Sprintf("failed to open database file: %s", path), cause, http.StatusInternalServerError). - WithStack() -} - -// DecryptReadFileFailed 创建无法读取数据库文件错误 -func DecryptReadFileFailed(path string, cause error) *AppError { - return New(ErrTypeWeChat, fmt.Sprintf("failed to read database file: %s", path), cause, http.StatusInternalServerError). - WithStack() -} - -// DecryptIncompleteRead 创建不完整的头部读取错误 -func DecryptIncompleteRead(cause error) *AppError { - return New(ErrTypeWeChat, "incomplete header read during decryption", cause, http.StatusInternalServerError). - WithStack() -} - -var ErrAlreadyDecrypted = New(ErrTypeWeChat, "database file is already decrypted", nil, http.StatusBadRequest) -var ErrDecryptHashVerificationFailed = New(ErrTypeWeChat, "hash verification failed during decryption", nil, http.StatusBadRequest) -var ErrDecryptIncorrectKey = New(ErrTypeWeChat, "incorrect decryption key", nil, http.StatusBadRequest) diff --git a/internal/errors/errors.go b/internal/errors/errors.go index d14b03f..d6d80bd 100644 --- a/internal/errors/errors.go +++ b/internal/errors/errors.go @@ -10,51 +10,29 @@ import ( "github.com/gin-gonic/gin" ) -// 定义错误类型常量 -const ( - ErrTypeDatabase = "database" - ErrTypeWeChat = "wechat" - ErrTypeHTTP = "http" - ErrTypeConfig = "config" - ErrTypeInvalidArg = "invalid_argument" - ErrTypeAuth = "authentication" - ErrTypePermission = "permission" - ErrTypeNotFound = "not_found" - ErrTypeValidation = "validation" - ErrTypeRateLimit = "rate_limit" - ErrTypeInternal = "internal" -) - -// AppError 表示应用程序错误 -type AppError struct { - Type string `json:"type"` // 错误类型 - Message string `json:"message"` // 错误消息 - Cause error `json:"-"` // 原始错误 - Code int `json:"-"` // HTTP Code - Stack []string `json:"-"` // 错误堆栈 - RequestID string `json:"request_id,omitempty"` // 请求ID,用于跟踪 +type Error struct { + Message string `json:"message"` // 错误消息 + Cause error `json:"-"` // 原始错误 + Code int `json:"-"` // HTTP Code + Stack []string `json:"-"` // 错误堆栈 } -// Error 实现 error 接口 -func (e *AppError) Error() string { +func (e *Error) Error() string { if e.Cause != nil { - return fmt.Sprintf("%s: %s: %v", e.Type, e.Message, e.Cause) + return fmt.Sprintf("%s: %v", e.Message, e.Cause) } - return fmt.Sprintf("%s: %s", e.Type, e.Message) + return fmt.Sprintf("%s", e.Message) } -// String 返回错误的字符串表示 -func (e *AppError) String() string { +func (e *Error) String() string { return e.Error() } -// Unwrap 实现 errors.Unwrap 接口,用于错误链 -func (e *AppError) Unwrap() error { +func (e *Error) Unwrap() error { return e.Cause } -// WithStack 添加堆栈信息到错误 -func (e *AppError) WithStack() *AppError { +func (e *Error) WithStack() *Error { const depth = 32 var pcs [depth]uintptr n := runtime.Callers(2, pcs[:]) @@ -75,32 +53,29 @@ func (e *AppError) WithStack() *AppError { return e } -// WithRequestID 添加请求ID到错误 -func (e *AppError) WithRequestID(requestID string) *AppError { - e.RequestID = requestID - return e -} - -// New 创建新的应用错误 -func New(errType, message string, cause error, code int) *AppError { - return &AppError{ - Type: errType, +func New(cause error, code int, message string) *Error { + return &Error{ Message: message, Cause: cause, Code: code, } } -// Wrap 包装现有错误为 AppError -func Wrap(err error, errType, message string, code int) *AppError { +func Newf(cause error, code int, format string, args ...interface{}) *Error { + return &Error{ + Message: fmt.Sprintf(format, args...), + Cause: cause, + Code: http.StatusInternalServerError, + } +} + +func Wrap(err error, message string, code int) *Error { if err == nil { return nil } - // 如果已经是 AppError,保留原始类型但更新消息 - if appErr, ok := err.(*AppError); ok { - return &AppError{ - Type: appErr.Type, + if appErr, ok := err.(*Error); ok { + return &Error{ Message: message, Cause: appErr.Cause, Code: appErr.Code, @@ -108,44 +83,15 @@ func Wrap(err error, errType, message string, code int) *AppError { } } - return New(errType, message, err, code) + return New(err, code, message) } -// Is 检查错误是否为特定类型 -func Is(err error, errType string) bool { - if err == nil { - return false - } - - var appErr *AppError - if errors.As(err, &appErr) { - return appErr.Type == errType - } - - return false -} - -// GetType 获取错误类型 -func GetType(err error) string { - if err == nil { - return "" - } - - var appErr *AppError - if errors.As(err, &appErr) { - return appErr.Type - } - - return "unknown" -} - -// GetCode 获取错误的 HTTP 状态码 func GetCode(err error) int { if err == nil { return http.StatusOK } - var appErr *AppError + var appErr *Error if errors.As(err, &appErr) { return appErr.Code } @@ -153,7 +99,6 @@ func GetCode(err error) int { return http.StatusInternalServerError } -// RootCause 获取错误链中的根本原因 func RootCause(err error) error { for err != nil { unwrapped := errors.Unwrap(err) @@ -165,81 +110,11 @@ func RootCause(err error) error { return err } -// ErrInvalidArg 无效参数错误 -func ErrInvalidArg(param string) *AppError { - return New(ErrTypeInvalidArg, fmt.Sprintf("invalid arg: %s", param), nil, http.StatusBadRequest).WithStack() -} - -// Database 创建数据库错误 -func Database(message string, cause error) *AppError { - return New(ErrTypeDatabase, message, cause, http.StatusInternalServerError).WithStack() -} - -// WeChat 创建微信相关错误 -func WeChat(message string, cause error) *AppError { - return New(ErrTypeWeChat, message, cause, http.StatusInternalServerError).WithStack() -} - -// HTTP 创建HTTP服务错误 -func HTTP(message string, cause error) *AppError { - return New(ErrTypeHTTP, message, cause, http.StatusInternalServerError).WithStack() -} - -// Config 创建配置错误 -func Config(message string, cause error) *AppError { - return New(ErrTypeConfig, message, cause, http.StatusInternalServerError).WithStack() -} - -// NotFound 创建资源不存在错误 -func NotFound(resource string, cause error) *AppError { - message := fmt.Sprintf("resource not found: %s", resource) - return New(ErrTypeNotFound, message, cause, http.StatusNotFound).WithStack() -} - -// Unauthorized 创建未授权错误 -func Unauthorized(message string, cause error) *AppError { - return New(ErrTypeAuth, message, cause, http.StatusUnauthorized).WithStack() -} - -// Forbidden 创建权限不足错误 -func Forbidden(message string, cause error) *AppError { - return New(ErrTypePermission, message, cause, http.StatusForbidden).WithStack() -} - -// Validation 创建数据验证错误 -func Validation(message string, cause error) *AppError { - return New(ErrTypeValidation, message, cause, http.StatusBadRequest).WithStack() -} - -// RateLimit 创建请求频率限制错误 -func RateLimit(message string, cause error) *AppError { - return New(ErrTypeRateLimit, message, cause, http.StatusTooManyRequests).WithStack() -} - -// Internal 创建内部服务器错误 -func Internal(message string, cause error) *AppError { - return New(ErrTypeInternal, message, cause, http.StatusInternalServerError).WithStack() -} - -// Err 在HTTP响应中返回错误 func Err(c *gin.Context, err error) { - // 获取请求ID(如果有) - requestID := c.GetString("RequestID") - - if appErr, ok := err.(*AppError); ok { - if requestID != "" { - appErr.RequestID = requestID - } + if appErr, ok := err.(*Error); ok { c.JSON(appErr.Code, appErr) return } - // 未知错误 - unknownErr := &AppError{ - Type: "unknown", - Message: err.Error(), - Code: http.StatusInternalServerError, - RequestID: requestID, - } - c.JSON(http.StatusInternalServerError, unknownErr) + c.JSON(http.StatusInternalServerError, err.Error()) } diff --git a/internal/errors/errors_test.go b/internal/errors/errors_test.go deleted file mode 100644 index e0a22a9..0000000 --- a/internal/errors/errors_test.go +++ /dev/null @@ -1,165 +0,0 @@ -package errors - -import ( - "fmt" - "net/http" - "testing" -) - -func TestErrorCreation(t *testing.T) { - // 测试创建基本错误 - err := New("test", "test message", nil, http.StatusBadRequest) - if err.Type != "test" || err.Message != "test message" || err.Code != http.StatusBadRequest { - t.Errorf("New() created incorrect error: %v", err) - } - - // 测试创建带原因的错误 - cause := fmt.Errorf("original error") - err = New("test", "test with cause", cause, http.StatusInternalServerError) - if err.Cause != cause { - t.Errorf("New() did not set cause correctly: %v", err) - } - - // 测试错误消息格式 - expected := "test: test with cause: original error" - if err.Error() != expected { - t.Errorf("Error() = %q, want %q", err.Error(), expected) - } -} - -func TestErrorWrapping(t *testing.T) { - // 测试包装普通错误 - original := fmt.Errorf("original error") - wrapped := Wrap(original, "wrapped", "wrapped message", http.StatusBadRequest) - - if wrapped.Type != "wrapped" || wrapped.Message != "wrapped message" { - t.Errorf("Wrap() created incorrect error: %v", wrapped) - } - - if wrapped.Cause != original { - t.Errorf("Wrap() did not set cause correctly") - } - - // 测试包装 AppError - appErr := New("app", "app error", nil, http.StatusNotFound) - rewrapped := Wrap(appErr, "ignored", "new message", http.StatusBadRequest) - - if rewrapped.Type != "app" { - t.Errorf("Wrap() did not preserve original AppError type: got %s, want %s", - rewrapped.Type, appErr.Type) - } - - if rewrapped.Message != "new message" { - t.Errorf("Wrap() did not update message: got %s, want %s", - rewrapped.Message, "new message") - } - - if rewrapped.Code != appErr.Code { - t.Errorf("Wrap() did not preserve original status code: got %d, want %d", - rewrapped.Code, appErr.Code) - } -} - -func TestErrorTypeChecking(t *testing.T) { - // 创建不同类型的错误 - dbErr := Database("db error", nil) - httpErr := HTTP("http error", nil) - - // 测试 Is 函数 - if !Is(dbErr, ErrTypeDatabase) { - t.Errorf("Is() failed to identify database error") - } - - if Is(dbErr, ErrTypeHTTP) { - t.Errorf("Is() incorrectly identified database error as HTTP error") - } - - if !Is(httpErr, ErrTypeHTTP) { - t.Errorf("Is() failed to identify HTTP error") - } - - // 测试 GetType 函数 - if GetType(dbErr) != ErrTypeDatabase { - t.Errorf("GetType() returned incorrect type: got %s, want %s", - GetType(dbErr), ErrTypeDatabase) - } - - if GetType(httpErr) != ErrTypeHTTP { - t.Errorf("GetType() returned incorrect type: got %s, want %s", - GetType(httpErr), ErrTypeHTTP) - } - - // 测试普通错误 - stdErr := fmt.Errorf("standard error") - if GetType(stdErr) != "unknown" { - t.Errorf("GetType() for standard error should return 'unknown', got %s", - GetType(stdErr)) - } -} - -func TestErrorUnwrapping(t *testing.T) { - // 创建嵌套错误 - innermost := fmt.Errorf("innermost error") - inner := Wrap(innermost, "inner", "inner error", http.StatusBadRequest) - outer := Wrap(inner, "outer", "outer error", http.StatusInternalServerError) - - // 测试 Unwrap - if unwrapped := outer.Unwrap(); unwrapped != inner.Cause { - t.Errorf("Unwrap() did not return correct inner error") - } - - // 测试 RootCause - if root := RootCause(outer); root != innermost { - t.Errorf("RootCause() did not return innermost error") - } -} - -func TestErrorHelperFunctions(t *testing.T) { - // 测试辅助函数 - invalidArg := ErrInvalidArg("username") - if invalidArg.Type != ErrTypeInvalidArg { - t.Errorf("ErrInvalidArg() created error with wrong type: %s", invalidArg.Type) - } - - dbErr := Database("query failed", nil) - if dbErr.Type != ErrTypeDatabase { - t.Errorf("Database() created error with wrong type: %s", dbErr.Type) - } - - notFound := NotFound("user", nil) - if notFound.Type != ErrTypeNotFound || notFound.Code != http.StatusNotFound { - t.Errorf("NotFound() created error with wrong type or code: %s, %d", - notFound.Type, notFound.Code) - } -} - -func TestErrorUtilityFunctions(t *testing.T) { - // 测试 JoinErrors - err1 := fmt.Errorf("error 1") - err2 := fmt.Errorf("error 2") - - // 单个错误 - if joined := JoinErrors(err1); joined != err1 { - t.Errorf("JoinErrors() with single error should return that error") - } - - // 多个错误 - joined := JoinErrors(err1, err2) - if joined == nil { - t.Errorf("JoinErrors() returned nil for multiple errors") - } - - // nil 错误 - if joined := JoinErrors(nil, nil); joined != nil { - t.Errorf("JoinErrors() with all nil should return nil") - } - - // 测试 WrapIfErr - if wrapped := WrapIfErr(nil, "test", "message", http.StatusOK); wrapped != nil { - t.Errorf("WrapIfErr() with nil should return nil") - } - - if wrapped := WrapIfErr(err1, "test", "message", http.StatusBadRequest); wrapped == nil { - t.Errorf("WrapIfErr() with non-nil error should return non-nil") - } -} diff --git a/internal/errors/http_errors.go b/internal/errors/http_errors.go new file mode 100644 index 0000000..5e13227 --- /dev/null +++ b/internal/errors/http_errors.go @@ -0,0 +1,11 @@ +package errors + +import "net/http" + +func InvalidArg(arg string) error { + return Newf(nil, http.StatusBadRequest, "invalid argument: %s", arg) +} + +func HTTPShutDown(cause error) error { + return Newf(cause, http.StatusInternalServerError, "http server shut down") +} diff --git a/internal/errors/middleware.go b/internal/errors/middleware.go index 82cffa3..2965923 100644 --- a/internal/errors/middleware.go +++ b/internal/errors/middleware.go @@ -1,12 +1,11 @@ package errors import ( - "fmt" "net/http" "github.com/gin-gonic/gin" "github.com/google/uuid" - log "github.com/sirupsen/logrus" + "github.com/rs/zerolog/log" ) // ErrorHandlerMiddleware 是一个 Gin 中间件,用于统一处理请求过程中的错误 @@ -40,21 +39,18 @@ func RecoveryMiddleware() gin.HandlerFunc { return func(c *gin.Context) { defer func() { if r := recover(); r != nil { - // 获取请求 ID - requestID, _ := c.Get("RequestID") - requestIDStr, _ := requestID.(string) // 创建内部服务器错误 - var err *AppError + var err *Error switch v := r.(type) { case error: - err = Internal("panic recovered", v).WithRequestID(requestIDStr) + err = New(v, http.StatusInternalServerError, "panic recovered") default: - err = Internal(fmt.Sprintf("panic recovered: %v", r), nil).WithRequestID(requestIDStr) + err = Newf(nil, http.StatusInternalServerError, "panic recovered: %v", r) } // 记录错误日志 - log.Errorf("PANIC RECOVERED: %v\n", err) + log.Err(err).Msg("PANIC RECOVERED") // 返回 500 错误 c.JSON(http.StatusInternalServerError, err) diff --git a/internal/errors/os_errors.go b/internal/errors/os_errors.go new file mode 100644 index 0000000..1ca02d3 --- /dev/null +++ b/internal/errors/os_errors.go @@ -0,0 +1,23 @@ +package errors + +import "net/http" + +func OpenFileFailed(path string, cause error) *Error { + return Newf(cause, http.StatusInternalServerError, "failed to open file: %s", path).WithStack() +} + +func StatFileFailed(path string, cause error) *Error { + return Newf(cause, http.StatusInternalServerError, "failed to stat file: %s", path).WithStack() +} + +func ReadFileFailed(path string, cause error) *Error { + return Newf(cause, http.StatusInternalServerError, "failed to read file: %s", path).WithStack() +} + +func IncompleteRead(cause error) *Error { + return New(cause, http.StatusInternalServerError, "incomplete header read during decryption").WithStack() +} + +func WriteOutputFailed(cause error) *Error { + return New(cause, http.StatusInternalServerError, "failed to write output").WithStack() +} diff --git a/internal/errors/utils.go b/internal/errors/utils.go deleted file mode 100644 index 8cef21e..0000000 --- a/internal/errors/utils.go +++ /dev/null @@ -1,131 +0,0 @@ -package errors - -import ( - stderrors "errors" - "fmt" - "strings" -) - -// WrapIfErr 如果 err 不为 nil,则包装错误并返回,否则返回 nil -func WrapIfErr(err error, errType, message string, code int) error { - if err == nil { - return nil - } - return Wrap(err, errType, message, code) -} - -// JoinErrors 将多个错误合并为一个错误 -// 如果只有一个错误不为 nil,则返回该错误 -// 如果有多个错误不为 nil,则创建一个包含所有错误信息的新错误 -func JoinErrors(errs ...error) error { - var nonNilErrs []error - for _, err := range errs { - if err != nil { - nonNilErrs = append(nonNilErrs, err) - } - } - - if len(nonNilErrs) == 0 { - return nil - } - - if len(nonNilErrs) == 1 { - return nonNilErrs[0] - } - - // 合并多个错误 - var messages []string - for _, err := range nonNilErrs { - messages = append(messages, err.Error()) - } - - return Internal( - fmt.Sprintf("multiple errors occurred: %s", strings.Join(messages, "; ")), - nonNilErrs[0], - ) -} - -// IsNil 检查错误是否为 nil -func IsNil(err error) bool { - return err == nil -} - -// IsNotNil 检查错误是否不为 nil -func IsNotNil(err error) bool { - return err != nil -} - -// IsType 检查错误是否为指定类型 -func IsType(err error, errType string) bool { - return Is(err, errType) -} - -// HasCause 检查错误是否包含指定的原因 -func HasCause(err error, cause error) bool { - if err == nil || cause == nil { - return false - } - - var appErr *AppError - if stderrors.As(err, &appErr) { - if appErr.Cause == cause { - return true - } - return HasCause(appErr.Cause, cause) - } - - return err == cause -} - -// AsAppError 将错误转换为 AppError 类型 -func AsAppError(err error) (*AppError, bool) { - var appErr *AppError - if stderrors.As(err, &appErr) { - return appErr, true - } - return nil, false -} - -// FormatErrorChain 格式化错误链,便于调试 -func FormatErrorChain(err error) string { - if err == nil { - return "" - } - - var result strings.Builder - result.WriteString(err.Error()) - - // 获取 AppError 类型的堆栈信息 - var appErr *AppError - if stderrors.As(err, &appErr) && len(appErr.Stack) > 0 { - result.WriteString("\nStack Trace:\n") - for _, frame := range appErr.Stack { - result.WriteString(" ") - result.WriteString(frame) - result.WriteString("\n") - } - } - - // 递归处理错误链 - cause := stderrors.Unwrap(err) - if cause != nil { - result.WriteString("\nCaused by: ") - result.WriteString(FormatErrorChain(cause)) - } - - return result.String() -} - -// GetErrorDetails 返回错误的详细信息,包括类型、消息、HTTP状态码和请求ID -func GetErrorDetails(err error) (errType string, message string, code int, requestID string) { - if err == nil { - return "", "", 0, "" - } - - var appErr *AppError - if stderrors.As(err, &appErr) { - return appErr.Type, appErr.Message, appErr.Code, appErr.RequestID - } - - return "unknown", err.Error(), 500, "" -} diff --git a/internal/errors/wechat_errors.go b/internal/errors/wechat_errors.go new file mode 100644 index 0000000..f6dcad3 --- /dev/null +++ b/internal/errors/wechat_errors.go @@ -0,0 +1,65 @@ +package errors + +import "net/http" + +var ( + ErrAlreadyDecrypted = New(nil, http.StatusBadRequest, "database file is already decrypted") + ErrDecryptHashVerificationFailed = New(nil, http.StatusBadRequest, "hash verification failed during decryption") + ErrDecryptIncorrectKey = New(nil, http.StatusBadRequest, "incorrect decryption key") + ErrDecryptOperationCanceled = New(nil, http.StatusBadRequest, "decryption operation was canceled") + ErrNoMemoryRegionsFound = New(nil, http.StatusBadRequest, "no memory regions found") + ErrReadMemoryTimeout = New(nil, http.StatusInternalServerError, "read memory timeout") + ErrWeChatOffline = New(nil, http.StatusBadRequest, "WeChat is offline") + ErrSIPEnabled = New(nil, http.StatusBadRequest, "SIP is enabled") + ErrValidatorNotSet = New(nil, http.StatusBadRequest, "validator not set") + ErrNoValidKey = New(nil, http.StatusBadRequest, "no valid key found") + ErrWeChatDLLNotFound = New(nil, http.StatusBadRequest, "WeChatWin.dll module not found") +) + +func PlatformUnsupported(platform string, version int) *Error { + return Newf(nil, http.StatusBadRequest, "unsupported platform: %s v%d", platform, version).WithStack() +} + +func DecryptCreateCipherFailed(cause error) *Error { + return New(cause, http.StatusInternalServerError, "failed to create cipher").WithStack() +} + +func DecodeKeyFailed(cause error) *Error { + return New(cause, http.StatusBadRequest, "failed to decode hex key").WithStack() +} + +func CreatePipeFileFailed(cause error) *Error { + return New(cause, http.StatusInternalServerError, "failed to create pipe file").WithStack() +} + +func OpenPipeFileFailed(cause error) *Error { + return New(cause, http.StatusInternalServerError, "failed to open pipe file").WithStack() +} + +func ReadPipeFileFailed(cause error) *Error { + return New(cause, http.StatusInternalServerError, "failed to read from pipe file").WithStack() +} + +func RunCmdFailed(cause error) *Error { + return New(cause, http.StatusInternalServerError, "failed to run command").WithStack() +} + +func ReadMemoryFailed(cause error) *Error { + return New(cause, http.StatusInternalServerError, "failed to read memory").WithStack() +} + +func OpenProcessFailed(cause error) *Error { + return New(cause, http.StatusInternalServerError, "failed to open process").WithStack() +} + +func WeChatAccountNotFound(name string) *Error { + return Newf(nil, http.StatusBadRequest, "WeChat account not found: %s", name).WithStack() +} + +func WeChatAccountNotOnline(name string) *Error { + return Newf(nil, http.StatusBadRequest, "WeChat account is not online: %s", name).WithStack() +} + +func RefreshProcessStatusFailed(cause error) *Error { + return New(cause, http.StatusInternalServerError, "failed to refresh process status").WithStack() +} diff --git a/internal/errors/wechatdb_errors.go b/internal/errors/wechatdb_errors.go new file mode 100644 index 0000000..42b6745 --- /dev/null +++ b/internal/errors/wechatdb_errors.go @@ -0,0 +1,62 @@ +package errors + +import ( + "net/http" + "time" +) + +var ( + ErrTalkerEmpty = New(nil, http.StatusBadRequest, "talker empty").WithStack() + ErrKeyEmpty = New(nil, http.StatusBadRequest, "key empty").WithStack() + ErrMediaNotFound = New(nil, http.StatusNotFound, "media not found").WithStack() + ErrKeyLengthMust32 = New(nil, http.StatusBadRequest, "key length must be 32 bytes").WithStack() +) + +// 数据库初始化相关错误 +func DBFileNotFound(path, pattern string, cause error) *Error { + return Newf(cause, http.StatusNotFound, "db file not found %s: %s", path, pattern).WithStack() +} + +func DBConnectFailed(path string, cause error) *Error { + return Newf(cause, http.StatusInternalServerError, "db connect failed: %s", path).WithStack() +} + +func DBInitFailed(cause error) *Error { + return New(cause, http.StatusInternalServerError, "db init failed").WithStack() +} + +func TalkerNotFound(talker string) *Error { + return Newf(nil, http.StatusNotFound, "talker not found: %s", talker).WithStack() +} + +func DBCloseFailed(cause error) *Error { + return New(cause, http.StatusInternalServerError, "db close failed").WithStack() +} + +func QueryFailed(query string, cause error) *Error { + return Newf(cause, http.StatusInternalServerError, "query failed: %s", query).WithStack() +} + +func ScanRowFailed(cause error) *Error { + return New(cause, http.StatusInternalServerError, "scan row failed").WithStack() +} + +func TimeRangeNotFound(start, end time.Time) *Error { + return Newf(nil, http.StatusNotFound, "time range not found: %s - %s", start, end).WithStack() +} + +func MediaTypeUnsupported(_type string) *Error { + return Newf(nil, http.StatusBadRequest, "unsupported media type: %s", _type).WithStack() +} + +func ChatRoomNotFound(key string) *Error { + return Newf(nil, http.StatusNotFound, "chat room not found: %s", key).WithStack() +} + +func ContactNotFound(key string) *Error { + return Newf(nil, http.StatusNotFound, "contact not found: %s", key).WithStack() +} + +func InitCacheFailed(cause error) *Error { + return New(cause, http.StatusInternalServerError, "init cache failed").WithStack() +} diff --git a/internal/mcp/mcp.go b/internal/mcp/mcp.go index 79306f7..dab9522 100644 --- a/internal/mcp/mcp.go +++ b/internal/mcp/mcp.go @@ -7,7 +7,7 @@ import ( "github.com/gin-gonic/gin" "github.com/google/uuid" - log "github.com/sirupsen/logrus" + "github.com/rs/zerolog/log" ) const ( @@ -85,7 +85,7 @@ func (m *MCP) HandleMessages(c *gin.Context) { return } - log.Printf("收到消息: %v\n", req) + log.Debug().Msgf("session: %s, request: %s", sessionID, req) select { case m.ProcessChan <- ProcessCtx{Session: session, Request: &req}: default: diff --git a/internal/wechat/decrypt/common/common.go b/internal/wechat/decrypt/common/common.go index 03b173d..dd12d1e 100644 --- a/internal/wechat/decrypt/common/common.go +++ b/internal/wechat/decrypt/common/common.go @@ -32,13 +32,13 @@ type DBFile struct { func OpenDBFile(dbPath string, pageSize int) (*DBFile, error) { fp, err := os.Open(dbPath) if err != nil { - return nil, errors.DecryptOpenFileFailed(dbPath, err) + return nil, errors.OpenFileFailed(dbPath, err) } defer fp.Close() fileInfo, err := fp.Stat() if err != nil { - return nil, errors.WeChatDecryptFailed(err) + return nil, errors.StatFileFailed(dbPath, err) } fileSize := fileInfo.Size() @@ -50,10 +50,10 @@ func OpenDBFile(dbPath string, pageSize int) (*DBFile, error) { buffer := make([]byte, pageSize) n, err := io.ReadFull(fp, buffer) if err != nil { - return nil, errors.DecryptReadFileFailed(dbPath, err) + return nil, errors.ReadFileFailed(dbPath, err) } if n != pageSize { - return nil, errors.DecryptIncompleteRead(fmt.Errorf("read %d bytes, expected %d", n, pageSize)) + return nil, errors.IncompleteRead(fmt.Errorf("read %d bytes, expected %d", n, pageSize)) } if bytes.Equal(buffer[:len(SQLiteHeader)-1], []byte(SQLiteHeader[:len(SQLiteHeader)-1])) { diff --git a/internal/wechat/decrypt/darwin/v3.go b/internal/wechat/decrypt/darwin/v3.go index 67aaa6c..e60772c 100644 --- a/internal/wechat/decrypt/darwin/v3.go +++ b/internal/wechat/decrypt/darwin/v3.go @@ -10,6 +10,7 @@ import ( "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/wechat/decrypt/common" + "golang.org/x/crypto/pbkdf2" ) @@ -75,7 +76,7 @@ func (d *V3Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, // 解码密钥 key, err := hex.DecodeString(hexKey) if err != nil { - return errors.DecryptDecodeKeyFailed(err) + return errors.DecodeKeyFailed(err) } // 打开数据库文件并读取基本信息 @@ -95,14 +96,14 @@ func (d *V3Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, // 打开数据库文件 dbFile, err := os.Open(dbfile) if err != nil { - return errors.DecryptOpenFileFailed(dbfile, err) + return errors.OpenFileFailed(dbfile, err) } defer dbFile.Close() // 写入 SQLite 头 _, err = output.Write([]byte(common.SQLiteHeader)) if err != nil { - return errors.DecryptWriteOutputFailed(err) + return errors.WriteOutputFailed(err) } // 处理每一页 @@ -112,7 +113,7 @@ func (d *V3Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, // 检查是否取消 select { case <-ctx.Done(): - return errors.DecryptOperationCanceled() + return errors.ErrDecryptOperationCanceled default: // 继续处理 } @@ -126,7 +127,7 @@ func (d *V3Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, break } } - return errors.DecryptReadFileFailed(dbfile, err) + return errors.ReadFileFailed(dbfile, err) } // 检查页面是否全为零 @@ -142,7 +143,7 @@ func (d *V3Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, // 写入零页面 _, err = output.Write(pageBuf) if err != nil { - return errors.DecryptWriteOutputFailed(err) + return errors.WriteOutputFailed(err) } continue } @@ -156,7 +157,7 @@ func (d *V3Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, // 写入解密后的页面 _, err = output.Write(decryptedData) if err != nil { - return errors.DecryptWriteOutputFailed(err) + return errors.WriteOutputFailed(err) } } diff --git a/internal/wechat/decrypt/darwin/v4.go b/internal/wechat/decrypt/darwin/v4.go index 13d7380..67e8528 100644 --- a/internal/wechat/decrypt/darwin/v4.go +++ b/internal/wechat/decrypt/darwin/v4.go @@ -80,7 +80,7 @@ func (d *V4Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, // 解码密钥 key, err := hex.DecodeString(hexKey) if err != nil { - return errors.DecryptDecodeKeyFailed(err) + return errors.DecodeKeyFailed(err) } // 打开数据库文件并读取基本信息 @@ -100,14 +100,14 @@ func (d *V4Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, // 打开数据库文件 dbFile, err := os.Open(dbfile) if err != nil { - return errors.DecryptOpenFileFailed(dbfile, err) + return errors.OpenFileFailed(dbfile, err) } defer dbFile.Close() // 写入SQLite头 _, err = output.Write([]byte(common.SQLiteHeader)) if err != nil { - return errors.DecryptWriteOutputFailed(err) + return errors.WriteOutputFailed(err) } // 处理每一页 @@ -117,7 +117,7 @@ func (d *V4Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, // 检查是否取消 select { case <-ctx.Done(): - return errors.DecryptOperationCanceled() + return errors.ErrDecryptOperationCanceled default: // 继续处理 } @@ -131,7 +131,7 @@ func (d *V4Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, break } } - return errors.DecryptReadFileFailed(dbfile, err) + return errors.ReadFileFailed(dbfile, err) } // 检查页面是否全为零 @@ -147,7 +147,7 @@ func (d *V4Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, // 写入零页面 _, err = output.Write(pageBuf) if err != nil { - return errors.DecryptWriteOutputFailed(err) + return errors.WriteOutputFailed(err) } continue } @@ -161,7 +161,7 @@ func (d *V4Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, // 写入解密后的页面 _, err = output.Write(decryptedData) if err != nil { - return errors.DecryptWriteOutputFailed(err) + return errors.WriteOutputFailed(err) } } diff --git a/internal/wechat/decrypt/decryptor.go b/internal/wechat/decrypt/decryptor.go index a48eb2f..97a8e35 100644 --- a/internal/wechat/decrypt/decryptor.go +++ b/internal/wechat/decrypt/decryptor.go @@ -2,7 +2,6 @@ package decrypt import ( "context" - "fmt" "io" "github.com/sjzar/chatlog/internal/errors" @@ -10,12 +9,6 @@ import ( "github.com/sjzar/chatlog/internal/wechat/decrypt/windows" ) -// 错误定义 -var ( - ErrInvalidVersion = fmt.Errorf("invalid version, must be 3 or 4") - ErrUnsupportedPlatform = fmt.Errorf("unsupported platform") -) - // Decryptor 定义数据库解密的接口 type Decryptor interface { // Decrypt 解密数据库 diff --git a/internal/wechat/decrypt/windows/v3.go b/internal/wechat/decrypt/windows/v3.go index 07f76ef..af50872 100644 --- a/internal/wechat/decrypt/windows/v3.go +++ b/internal/wechat/decrypt/windows/v3.go @@ -78,7 +78,7 @@ func (d *V3Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, // 解码密钥 key, err := hex.DecodeString(hexKey) if err != nil { - return errors.DecryptDecodeKeyFailed(err) + return errors.DecodeKeyFailed(err) } // 打开数据库文件并读取基本信息 @@ -98,14 +98,14 @@ func (d *V3Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, // 打开数据库文件 dbFile, err := os.Open(dbfile) if err != nil { - return errors.DecryptOpenFileFailed(dbfile, err) + return errors.OpenFileFailed(dbfile, err) } defer dbFile.Close() // 写入SQLite头 _, err = output.Write([]byte(common.SQLiteHeader)) if err != nil { - return errors.DecryptWriteOutputFailed(err) + return errors.WriteOutputFailed(err) } // 处理每一页 @@ -115,7 +115,7 @@ func (d *V3Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, // 检查是否取消 select { case <-ctx.Done(): - return errors.DecryptOperationCanceled() + return errors.ErrDecryptOperationCanceled default: // 继续处理 } @@ -129,7 +129,7 @@ func (d *V3Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, break } } - return errors.DecryptReadFileFailed(dbfile, err) + return errors.ReadFileFailed(dbfile, err) } // 检查页面是否全为零 @@ -145,7 +145,7 @@ func (d *V3Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, // 写入零页面 _, err = output.Write(pageBuf) if err != nil { - return errors.DecryptWriteOutputFailed(err) + return errors.WriteOutputFailed(err) } continue } @@ -159,7 +159,7 @@ func (d *V3Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, // 写入解密后的页面 _, err = output.Write(decryptedData) if err != nil { - return errors.DecryptWriteOutputFailed(err) + return errors.WriteOutputFailed(err) } } diff --git a/internal/wechat/decrypt/windows/v4.go b/internal/wechat/decrypt/windows/v4.go index bd2e519..4585010 100644 --- a/internal/wechat/decrypt/windows/v4.go +++ b/internal/wechat/decrypt/windows/v4.go @@ -10,6 +10,7 @@ import ( "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/wechat/decrypt/common" + "golang.org/x/crypto/pbkdf2" ) @@ -76,7 +77,7 @@ func (d *V4Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, // 解码密钥 key, err := hex.DecodeString(hexKey) if err != nil { - return errors.DecryptDecodeKeyFailed(err) + return errors.DecodeKeyFailed(err) } // 打开数据库文件并读取基本信息 @@ -96,14 +97,14 @@ func (d *V4Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, // 打开数据库文件 dbFile, err := os.Open(dbfile) if err != nil { - return errors.DecryptOpenFileFailed(dbfile, err) + return errors.OpenFileFailed(dbfile, err) } defer dbFile.Close() // 写入SQLite头 _, err = output.Write([]byte(common.SQLiteHeader)) if err != nil { - return errors.DecryptWriteOutputFailed(err) + return errors.WriteOutputFailed(err) } // 处理每一页 @@ -113,7 +114,7 @@ func (d *V4Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, // 检查是否取消 select { case <-ctx.Done(): - return errors.DecryptOperationCanceled() + return errors.ErrDecryptOperationCanceled default: // 继续处理 } @@ -127,7 +128,7 @@ func (d *V4Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, break } } - return errors.DecryptReadFileFailed(dbfile, err) + return errors.ReadFileFailed(dbfile, err) } // 检查页面是否全为零 @@ -143,7 +144,7 @@ func (d *V4Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, // 写入零页面 _, err = output.Write(pageBuf) if err != nil { - return errors.DecryptWriteOutputFailed(err) + return errors.WriteOutputFailed(err) } continue } @@ -157,7 +158,7 @@ func (d *V4Decryptor) Decrypt(ctx context.Context, dbfile string, hexKey string, // 写入解密后的页面 _, err = output.Write(decryptedData) if err != nil { - return errors.DecryptWriteOutputFailed(err) + return errors.WriteOutputFailed(err) } } diff --git a/internal/wechat/key/darwin/glance/glance.go b/internal/wechat/key/darwin/glance/glance.go index 7bf0c6d..714f88d 100644 --- a/internal/wechat/key/darwin/glance/glance.go +++ b/internal/wechat/key/darwin/glance/glance.go @@ -8,6 +8,9 @@ import ( "os/exec" "path/filepath" "time" + + "github.com/rs/zerolog/log" + "github.com/sjzar/chatlog/internal/errors" ) // FIXME 按照 region 读取效率较低,512MB 内存读取耗时约 18s @@ -38,14 +41,14 @@ func (g *Glance) Read() ([]byte, error) { g.MemRegions = MemRegionsFilter(regions) if len(g.MemRegions) == 0 { - return nil, fmt.Errorf("no memory regions found") + return nil, errors.ErrNoMemoryRegionsFound } region := g.MemRegions[0] // 1. Create pipe file if err := exec.Command("mkfifo", g.pipePath).Run(); err != nil { - return nil, fmt.Errorf("failed to create pipe file: %w", err) + return nil, errors.CreatePipeFileFailed(err) } defer os.Remove(g.pipePath) @@ -56,7 +59,7 @@ func (g *Glance) Read() ([]byte, error) { // Open pipe for reading file, err := os.OpenFile(g.pipePath, os.O_RDONLY, 0600) if err != nil { - errCh <- fmt.Errorf("failed to open pipe for reading: %w", err) + errCh <- errors.OpenPipeFileFailed(err) return } defer file.Close() @@ -64,7 +67,7 @@ func (g *Glance) Read() ([]byte, error) { // Read all data from pipe data, err := io.ReadAll(file) if err != nil { - errCh <- fmt.Errorf("failed to read from pipe: %w", err) + errCh <- errors.ReadPipeFileFailed(err) return } dataCh <- data @@ -80,12 +83,12 @@ func (g *Glance) Read() ([]byte, error) { // Set up stdout pipe for monitoring (optional) stdout, err := cmd.StdoutPipe() if err != nil { - return nil, fmt.Errorf("failed to create stdout pipe: %w", err) + return nil, err } // Start the command if err := cmd.Start(); err != nil { - return nil, fmt.Errorf("failed to start lldb: %w", err) + return nil, errors.RunCmdFailed(err) } // Monitor lldb output (optional) @@ -102,16 +105,16 @@ func (g *Glance) Read() ([]byte, error) { case data := <-dataCh: g.data = data case err := <-errCh: - return nil, fmt.Errorf("failed to read memory: %w", err) + return nil, errors.ReadMemoryFailed(err) case <-time.After(30 * time.Second): cmd.Process.Kill() - return nil, fmt.Errorf("timeout waiting for memory data") + return nil, errors.ErrReadMemoryTimeout } // Wait for the command to finish if err := cmd.Wait(); err != nil { // We already have the data, so just log the error - fmt.Printf("Warning: lldb process exited with error: %v\n", err) + log.Err(err).Msg("lldb process exited with error") } return g.data, nil diff --git a/internal/wechat/key/darwin/glance/vmmap.go b/internal/wechat/key/darwin/glance/vmmap.go index d87b18b..ca1754f 100644 --- a/internal/wechat/key/darwin/glance/vmmap.go +++ b/internal/wechat/key/darwin/glance/vmmap.go @@ -7,6 +7,8 @@ import ( "regexp" "strconv" "strings" + + "github.com/sjzar/chatlog/internal/errors" ) const ( @@ -31,7 +33,7 @@ func GetVmmap(pid uint32) ([]MemRegion, error) { cmd := exec.Command(CommandVmmap, "-wide", fmt.Sprintf("%d", pid)) output, err := cmd.CombinedOutput() if err != nil { - return nil, fmt.Errorf("error executing vmmap command: %w", err) + return nil, errors.RunCmdFailed(err) } // Parse the output using the existing LoadVmmap function diff --git a/internal/wechat/key/darwin/v3.go b/internal/wechat/key/darwin/v3.go index 46e0120..8647b04 100644 --- a/internal/wechat/key/darwin/v3.go +++ b/internal/wechat/key/darwin/v3.go @@ -4,12 +4,12 @@ import ( "bytes" "context" "encoding/hex" - "fmt" "runtime" "sync" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog/log" + "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/wechat/decrypt" "github.com/sjzar/chatlog/internal/wechat/key/darwin/glance" "github.com/sjzar/chatlog/internal/wechat/model" @@ -29,16 +29,16 @@ func NewV3Extractor() *V3Extractor { func (e *V3Extractor) Extract(ctx context.Context, proc *model.Process) (string, error) { if proc.Status == model.StatusOffline { - return "", fmt.Errorf("WeChat is offline") + return "", errors.ErrWeChatOffline } // Check if SIP is disabled, as it's required for memory reading on macOS if !glance.IsSIPDisabled() { - return "", fmt.Errorf("System Integrity Protection (SIP) is enabled, cannot read process memory") + return "", errors.ErrSIPEnabled } if e.validator == nil { - return "", fmt.Errorf("validator not set") + return "", errors.ErrValidatorNotSet } // Create context to control all goroutines @@ -57,7 +57,7 @@ func (e *V3Extractor) Extract(ctx context.Context, proc *model.Process) (string, if workerCount > MaxWorkersV3 { workerCount = MaxWorkersV3 } - logrus.Debug("Starting ", workerCount, " workers for V3 key search") + log.Debug().Msgf("Starting %d workers for V3 key search", workerCount) // Start consumer goroutines var workerWaitGroup sync.WaitGroup @@ -77,7 +77,7 @@ func (e *V3Extractor) Extract(ctx context.Context, proc *model.Process) (string, defer close(memoryChannel) // Close channel when producer is done err := e.findMemory(searchCtx, uint32(proc.PID), memoryChannel) if err != nil { - logrus.Error(err) + log.Err(err).Msg("Failed to read memory") } }() @@ -98,7 +98,7 @@ func (e *V3Extractor) Extract(ctx context.Context, proc *model.Process) (string, } } - return "", fmt.Errorf("no valid key found") + return "", errors.ErrNoValidKey } // findMemory searches for memory regions using Glance @@ -109,15 +109,15 @@ func (e *V3Extractor) findMemory(ctx context.Context, pid uint32, memoryChannel // Read memory data memory, err := g.Read() if err != nil { - return fmt.Errorf("failed to read process memory: %w", err) + return err } - logrus.Debug("Read memory region, size: ", len(memory), " bytes") + log.Debug().Msgf("Read memory region, size: %d bytes", len(memory)) // Send memory data to channel for processing select { case memoryChannel <- memory: - logrus.Debug("Sent memory region for analysis") + log.Debug().Msg("Memory region sent for analysis") case <-ctx.Done(): return ctx.Err() } @@ -146,7 +146,7 @@ func (e *V3Extractor) worker(ctx context.Context, memoryChannel <-chan []byte, r default: } - logrus.Debugf("Searching for V3 key in memory region, size: %d bytes", len(memory)) + log.Debug().Msgf("Searching for V3 key in memory region, size: %d bytes", len(memory)) // Find pattern from end to beginning index = bytes.LastIndex(memory[:index], keyPattern) @@ -154,7 +154,7 @@ func (e *V3Extractor) worker(ctx context.Context, memoryChannel <-chan []byte, r break // No more matches found } - logrus.Debugf("Found potential V3 key pattern in memory region, index: %d", index) + log.Debug().Msgf("Found potential V3 key pattern in memory region, index: %d", index) // For V3, the key is 32 bytes and starts right after the pattern if index+24+32 > len(memory) { @@ -170,7 +170,7 @@ func (e *V3Extractor) worker(ctx context.Context, memoryChannel <-chan []byte, r if e.validator.Validate(keyData) { select { case resultChannel <- hex.EncodeToString(keyData): - logrus.Debug("Valid key found for V3 database") + log.Debug().Msg("Key found: " + hex.EncodeToString(keyData)) return default: } diff --git a/internal/wechat/key/darwin/v4.go b/internal/wechat/key/darwin/v4.go index b1f8534..c551dcb 100644 --- a/internal/wechat/key/darwin/v4.go +++ b/internal/wechat/key/darwin/v4.go @@ -4,12 +4,12 @@ import ( "bytes" "context" "encoding/hex" - "fmt" "runtime" "sync" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog/log" + "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/wechat/decrypt" "github.com/sjzar/chatlog/internal/wechat/key/darwin/glance" "github.com/sjzar/chatlog/internal/wechat/model" @@ -29,16 +29,16 @@ func NewV4Extractor() *V4Extractor { func (e *V4Extractor) Extract(ctx context.Context, proc *model.Process) (string, error) { if proc.Status == model.StatusOffline { - return "", fmt.Errorf("WeChat is offline") + return "", errors.ErrWeChatOffline } // Check if SIP is disabled, as it's required for memory reading on macOS if !glance.IsSIPDisabled() { - return "", fmt.Errorf("System Integrity Protection (SIP) is enabled, cannot read process memory") + return "", errors.ErrSIPEnabled } if e.validator == nil { - return "", fmt.Errorf("validator not set") + return "", errors.ErrValidatorNotSet } // Create context to control all goroutines @@ -57,7 +57,7 @@ func (e *V4Extractor) Extract(ctx context.Context, proc *model.Process) (string, if workerCount > MaxWorkers { workerCount = MaxWorkers } - logrus.Debug("Starting ", workerCount, " workers for V4 key search") + log.Debug().Msgf("Starting %d workers for V4 key search", workerCount) // Start consumer goroutines var workerWaitGroup sync.WaitGroup @@ -77,7 +77,7 @@ func (e *V4Extractor) Extract(ctx context.Context, proc *model.Process) (string, defer close(memoryChannel) // Close channel when producer is done err := e.findMemory(searchCtx, uint32(proc.PID), memoryChannel) if err != nil { - logrus.Error(err) + log.Err(err).Msg("Failed to read memory") } }() @@ -98,7 +98,7 @@ func (e *V4Extractor) Extract(ctx context.Context, proc *model.Process) (string, } } - return "", fmt.Errorf("no valid key found") + return "", errors.ErrNoValidKey } // findMemory searches for memory regions using Glance @@ -109,15 +109,15 @@ func (e *V4Extractor) findMemory(ctx context.Context, pid uint32, memoryChannel // Read memory data memory, err := g.Read() if err != nil { - return fmt.Errorf("failed to read process memory: %w", err) + return err } - logrus.Debug("Read memory region, size: ", len(memory), " bytes") + log.Debug().Msgf("Read memory region, size: %d bytes", len(memory)) // Send memory data to channel for processing select { case memoryChannel <- memory: - logrus.Debug("Sent memory region for analysis") + log.Debug().Msg("Memory region sent for analysis") case <-ctx.Done(): return ctx.Err() } @@ -167,7 +167,7 @@ func (e *V4Extractor) worker(ctx context.Context, memoryChannel <-chan []byte, r if e.validator.Validate(keyData) { select { case resultChannel <- hex.EncodeToString(keyData): - logrus.Debug("Valid key found for V4 database") + log.Debug().Msg("Key found: " + hex.EncodeToString(keyData)) return default: } diff --git a/internal/wechat/key/extractor.go b/internal/wechat/key/extractor.go index 7586658..0873aef 100644 --- a/internal/wechat/key/extractor.go +++ b/internal/wechat/key/extractor.go @@ -2,20 +2,14 @@ package key import ( "context" - "fmt" + "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/wechat/decrypt" "github.com/sjzar/chatlog/internal/wechat/key/darwin" "github.com/sjzar/chatlog/internal/wechat/key/windows" "github.com/sjzar/chatlog/internal/wechat/model" ) -// 错误定义 -var ( - ErrInvalidVersion = fmt.Errorf("invalid version, must be 3 or 4") - ErrUnsupportedPlatform = fmt.Errorf("unsupported platform") -) - // Extractor 定义密钥提取器接口 type Extractor interface { // Extract 从进程中提取密钥 @@ -36,6 +30,6 @@ func NewExtractor(platform string, version int) (Extractor, error) { case platform == "darwin" && version == 4: return darwin.NewV4Extractor(), nil default: - return nil, fmt.Errorf("%w: %s v%d", ErrUnsupportedPlatform, platform, version) + return nil, errors.PlatformUnsupported(platform, version) } } diff --git a/internal/wechat/key/windows/v3.go b/internal/wechat/key/windows/v3.go index 0034a24..7567e80 100644 --- a/internal/wechat/key/windows/v3.go +++ b/internal/wechat/key/windows/v3.go @@ -1,20 +1,9 @@ package windows import ( - "errors" - "github.com/sjzar/chatlog/internal/wechat/decrypt" ) -// Common error definitions -var ( - ErrWeChatOffline = errors.New("wechat is not logged in") - ErrOpenProcess = errors.New("failed to open process") - ErrCheckProcessBits = errors.New("failed to check process architecture") - ErrFindWeChatDLL = errors.New("WeChatWin.dll module not found") - ErrNoValidKey = errors.New("no valid key found") -) - type V3Extractor struct { validator *decrypt.Validator } diff --git a/internal/wechat/key/windows/v3_windows.go b/internal/wechat/key/windows/v3_windows.go index 6f56616..97f9883 100644 --- a/internal/wechat/key/windows/v3_windows.go +++ b/internal/wechat/key/windows/v3_windows.go @@ -10,9 +10,10 @@ import ( "sync" "unsafe" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog/log" "golang.org/x/sys/windows" + "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/wechat/model" "github.com/sjzar/chatlog/pkg/util" ) @@ -24,20 +25,20 @@ const ( func (e *V3Extractor) Extract(ctx context.Context, proc *model.Process) (string, error) { if proc.Status == model.StatusOffline { - return "", ErrWeChatOffline + return "", errors.ErrWeChatOffline } // Open WeChat process handle, err := windows.OpenProcess(windows.PROCESS_QUERY_INFORMATION|windows.PROCESS_VM_READ, false, proc.PID) if err != nil { - return "", fmt.Errorf("%w: %v", ErrOpenProcess, err) + return "", errors.OpenProcessFailed(err) } defer windows.CloseHandle(handle) // Check process architecture is64Bit, err := util.Is64Bit(handle) if err != nil { - return "", fmt.Errorf("%w: %v", ErrCheckProcessBits, err) + return "", err } // Create context to control all goroutines @@ -56,7 +57,7 @@ func (e *V3Extractor) Extract(ctx context.Context, proc *model.Process) (string, if workerCount > MaxWorkers { workerCount = MaxWorkers } - logrus.Debug("Starting ", workerCount, " workers for V3 key search") + log.Debug().Msgf("Starting %d workers for V3 key search", workerCount) // Start consumer goroutines var workerWaitGroup sync.WaitGroup @@ -76,7 +77,7 @@ func (e *V3Extractor) Extract(ctx context.Context, proc *model.Process) (string, defer close(memoryChannel) // Close channel when producer is done err := e.findMemory(searchCtx, handle, proc.PID, memoryChannel) if err != nil { - logrus.Error(err) + log.Err(err).Msg("Failed to find memory regions") } }() @@ -97,7 +98,7 @@ func (e *V3Extractor) Extract(ctx context.Context, proc *model.Process) (string, } } - return "", ErrNoValidKey + return "", errors.ErrNoValidKey } // findMemoryV3 searches for writable memory regions in WeChatWin.dll for V3 version @@ -105,9 +106,9 @@ func (e *V3Extractor) findMemory(ctx context.Context, handle windows.Handle, pid // Find WeChatWin.dll module module, isFound := FindModule(pid, V3ModuleName) if !isFound { - return ErrFindWeChatDLL + return errors.ErrWeChatDLLNotFound } - logrus.Debug("Found WeChatWin.dll module at base address: 0x", fmt.Sprintf("%X", module.ModBaseAddr)) + log.Debug().Msg("Found WeChatWin.dll module at base address: 0x" + fmt.Sprintf("%X", module.ModBaseAddr)) // Read writable memory regions baseAddr := uintptr(module.ModBaseAddr) @@ -141,7 +142,7 @@ func (e *V3Extractor) findMemory(ctx context.Context, handle windows.Handle, pid if err = windows.ReadProcessMemory(handle, currentAddr, &memory[0], regionSize, nil); err == nil { select { case memoryChannel <- memory: - logrus.Debug("Sent memory region for analysis, size: ", regionSize, " bytes") + log.Debug().Msgf("Memory region: 0x%X - 0x%X, size: %d bytes", currentAddr, currentAddr+regionSize, regionSize) case <-ctx.Done(): return nil } @@ -198,7 +199,7 @@ func (e *V3Extractor) worker(ctx context.Context, handle windows.Handle, is64Bit if key := e.validateKey(handle, ptrValue); key != "" { select { case resultChannel <- key: - logrus.Debug("Valid key found for V3 database") + log.Debug().Msg("Valid key found: " + key) return default: } @@ -230,7 +231,7 @@ func FindModule(pid uint32, name string) (module windows.ModuleEntry32, isFound // Create module snapshot snapshot, err := windows.CreateToolhelp32Snapshot(windows.TH32CS_SNAPMODULE|windows.TH32CS_SNAPMODULE32, pid) if err != nil { - logrus.Debug("Failed to create module snapshot: ", err) + log.Debug().Msgf("Failed to create module snapshot for PID %d: %v", pid, err) return module, false } defer windows.CloseHandle(snapshot) @@ -240,7 +241,7 @@ func FindModule(pid uint32, name string) (module windows.ModuleEntry32, isFound // Get the first module if err := windows.Module32First(snapshot, &module); err != nil { - logrus.Debug("Failed to get first module: ", err) + log.Debug().Msgf("Module32First failed for PID %d: %v", pid, err) return module, false } diff --git a/internal/wechat/key/windows/v4_windows.go b/internal/wechat/key/windows/v4_windows.go index 5eafaa4..4f14771 100644 --- a/internal/wechat/key/windows/v4_windows.go +++ b/internal/wechat/key/windows/v4_windows.go @@ -5,14 +5,14 @@ import ( "context" "encoding/binary" "encoding/hex" - "fmt" "runtime" "sync" "unsafe" - "github.com/sirupsen/logrus" + "github.com/rs/zerolog/log" "golang.org/x/sys/windows" + "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/wechat/model" ) @@ -22,13 +22,13 @@ const ( func (e *V4Extractor) Extract(ctx context.Context, proc *model.Process) (string, error) { if proc.Status == model.StatusOffline { - return "", ErrWeChatOffline + return "", errors.ErrWeChatOffline } // Open process handle handle, err := windows.OpenProcess(windows.PROCESS_VM_READ|windows.PROCESS_QUERY_INFORMATION, false, proc.PID) if err != nil { - return "", fmt.Errorf("%w: %v", ErrOpenProcess, err) + return "", errors.OpenProcessFailed(err) } defer windows.CloseHandle(handle) @@ -48,7 +48,7 @@ func (e *V4Extractor) Extract(ctx context.Context, proc *model.Process) (string, if workerCount > MaxWorkers { workerCount = MaxWorkers } - logrus.Debug("Starting ", workerCount, " workers for V4 key search") + log.Debug().Msgf("Starting %d workers for V4 key search", workerCount) // Start consumer goroutines var workerWaitGroup sync.WaitGroup @@ -68,7 +68,7 @@ func (e *V4Extractor) Extract(ctx context.Context, proc *model.Process) (string, defer close(memoryChannel) // Close channel when producer is done err := e.findMemory(searchCtx, handle, memoryChannel) if err != nil { - logrus.Error(err) + log.Err(err).Msg("Failed to find memory regions") } }() @@ -89,7 +89,7 @@ func (e *V4Extractor) Extract(ctx context.Context, proc *model.Process) (string, } } - return "", ErrNoValidKey + return "", errors.ErrNoValidKey } // findMemoryV4 searches for writable memory regions for V4 version @@ -101,7 +101,7 @@ func (e *V4Extractor) findMemory(ctx context.Context, handle windows.Handle, mem if runtime.GOARCH == "amd64" { maxAddr = uintptr(0x7FFFFFFFFFFF) // 64-bit process space limit } - logrus.Debug("Scanning memory regions from 0x", fmt.Sprintf("%X", minAddr), " to 0x", fmt.Sprintf("%X", maxAddr)) + log.Debug().Msgf("Scanning memory regions from 0x%X to 0x%X", minAddr, maxAddr) currentAddr := minAddr @@ -131,7 +131,7 @@ func (e *V4Extractor) findMemory(ctx context.Context, handle windows.Handle, mem if err = windows.ReadProcessMemory(handle, currentAddr, &memory[0], regionSize, nil); err == nil { select { case memoryChannel <- memory: - logrus.Debug("Sent memory region for analysis, size: ", regionSize, " bytes") + log.Debug().Msgf("Memory region for analysis: 0x%X - 0x%X, size: %d bytes", currentAddr, currentAddr+regionSize, regionSize) case <-ctx.Done(): return nil } @@ -185,7 +185,7 @@ func (e *V4Extractor) worker(ctx context.Context, handle windows.Handle, memoryC if key := e.validateKey(handle, ptrValue); key != "" { select { case resultChannel <- key: - logrus.Debug("Valid key found for V4 database") + log.Debug().Msg("Valid key found: " + key) return default: } diff --git a/internal/wechat/manager.go b/internal/wechat/manager.go index dd6bd42..b317f35 100644 --- a/internal/wechat/manager.go +++ b/internal/wechat/manager.go @@ -2,9 +2,9 @@ package wechat import ( "context" - "fmt" "runtime" + "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/wechat/model" "github.com/sjzar/chatlog/internal/wechat/process" ) @@ -87,7 +87,7 @@ func (m *Manager) GetAccount(name string) (*Account, error) { func (m *Manager) GetProcess(name string) (*model.Process, error) { p, ok := m.processMap[name] if !ok { - return nil, fmt.Errorf("account not found: %s", name) + return nil, errors.WeChatAccountNotFound(name) } return p, nil } diff --git a/internal/wechat/process/darwin/detector.go b/internal/wechat/process/darwin/detector.go index 68a3294..d3a978b 100644 --- a/internal/wechat/process/darwin/detector.go +++ b/internal/wechat/process/darwin/detector.go @@ -1,15 +1,15 @@ package darwin import ( - "fmt" "os/exec" "path/filepath" "strconv" "strings" + "github.com/rs/zerolog/log" "github.com/shirou/gopsutil/v4/process" - log "github.com/sirupsen/logrus" + "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/wechat/model" "github.com/sjzar/chatlog/pkg/appver" ) @@ -33,7 +33,7 @@ func NewDetector() *Detector { func (d *Detector) FindProcesses() ([]*model.Process, error) { processes, err := process.Processes() if err != nil { - log.Errorf("获取进程列表失败: %v", err) + log.Err(err).Msg("获取进程列表失败") return nil, err } @@ -47,7 +47,7 @@ func (d *Detector) FindProcesses() ([]*model.Process, error) { // 获取进程信息 procInfo, err := d.getProcessInfo(p) if err != nil { - log.Errorf("获取进程 %d 的信息失败: %v", p.Pid, err) + log.Err(err).Msgf("获取进程 %d 的信息失败", p.Pid) continue } @@ -68,7 +68,7 @@ func (d *Detector) getProcessInfo(p *process.Process) (*model.Process, error) { // 获取可执行文件路径 exePath, err := p.Exe() if err != nil { - log.Error(err) + log.Err(err).Msg("获取可执行文件路径失败") return nil, err } procInfo.ExePath = exePath @@ -77,7 +77,7 @@ func (d *Detector) getProcessInfo(p *process.Process) (*model.Process, error) { // 注意:macOS 的版本获取方式可能与 Windows 不同 versionInfo, err := appver.New(exePath) if err != nil { - log.Error(err) + log.Err(err).Msg("获取版本信息失败") procInfo.Version = 3 procInfo.FullVersion = "3.0.0" } else { @@ -87,7 +87,7 @@ func (d *Detector) getProcessInfo(p *process.Process) (*model.Process, error) { // 初始化附加信息(数据目录、账户名) if err := d.initializeProcessInfo(p, procInfo); err != nil { - log.Errorf("初始化进程信息失败: %v", err) + log.Err(err).Msg("初始化进程信息失败") // 即使初始化失败也返回部分信息 } @@ -99,7 +99,7 @@ func (d *Detector) initializeProcessInfo(p *process.Process, info *model.Process // 使用 lsof 命令获取进程打开的文件 files, err := d.getOpenFiles(int(p.Pid)) if err != nil { - log.Error("获取打开文件列表失败: ", err) + log.Err(err).Msg("获取打开的文件失败") return err } @@ -112,7 +112,7 @@ func (d *Detector) initializeProcessInfo(p *process.Process, info *model.Process if strings.Contains(filePath, dbPath) { parts := strings.Split(filePath, string(filepath.Separator)) if len(parts) < 4 { - log.Debug("无效的文件路径格式: " + filePath) + log.Debug().Msg("无效的文件路径格式: " + filePath) continue } @@ -142,7 +142,7 @@ func (d *Detector) getOpenFiles(pid int) ([]string, error) { cmd := exec.Command("lsof", "-p", strconv.Itoa(pid), "-F", "n") output, err := cmd.Output() if err != nil { - return nil, fmt.Errorf("执行 lsof 命令失败: %v", err) + return nil, errors.RunCmdFailed(err) } // 解析 lsof -F n 输出 diff --git a/internal/wechat/process/windows/detector.go b/internal/wechat/process/windows/detector.go index 88b304a..ccec205 100644 --- a/internal/wechat/process/windows/detector.go +++ b/internal/wechat/process/windows/detector.go @@ -3,8 +3,8 @@ package windows import ( "strings" + "github.com/rs/zerolog/log" "github.com/shirou/gopsutil/v4/process" - log "github.com/sirupsen/logrus" "github.com/sjzar/chatlog/internal/wechat/model" "github.com/sjzar/chatlog/pkg/appver" @@ -29,7 +29,7 @@ func NewDetector() *Detector { func (d *Detector) FindProcesses() ([]*model.Process, error) { processes, err := process.Processes() if err != nil { - log.Errorf("获取进程列表失败: %v", err) + log.Err(err).Msg("获取进程列表失败") return nil, err } @@ -45,7 +45,7 @@ func (d *Detector) FindProcesses() ([]*model.Process, error) { if name == V4ProcessName { cmdline, err := p.Cmdline() if err != nil { - log.Error(err) + log.Err(err).Msg("获取进程命令行失败") continue } if strings.Contains(cmdline, "--") { @@ -56,7 +56,7 @@ func (d *Detector) FindProcesses() ([]*model.Process, error) { // 获取进程信息 procInfo, err := d.getProcessInfo(p) if err != nil { - log.Errorf("获取进程 %d 的信息失败: %v", p.Pid, err) + log.Err(err).Msgf("获取进程 %d 的信息失败", p.Pid) continue } @@ -77,7 +77,7 @@ func (d *Detector) getProcessInfo(p *process.Process) (*model.Process, error) { // 获取可执行文件路径 exePath, err := p.Exe() if err != nil { - log.Error(err) + log.Err(err).Msg("获取可执行文件路径失败") return nil, err } procInfo.ExePath = exePath @@ -85,7 +85,7 @@ func (d *Detector) getProcessInfo(p *process.Process) (*model.Process, error) { // 获取版本信息 versionInfo, err := appver.New(exePath) if err != nil { - log.Error(err) + log.Err(err).Msg("获取版本信息失败") return nil, err } procInfo.Version = versionInfo.Version @@ -93,7 +93,7 @@ func (d *Detector) getProcessInfo(p *process.Process) (*model.Process, error) { // 初始化附加信息(数据目录、账户名) if err := initializeProcessInfo(p, procInfo); err != nil { - log.Errorf("初始化进程信息失败: %v", err) + log.Err(err).Msg("初始化进程信息失败") // 即使初始化失败也返回部分信息 } diff --git a/internal/wechat/process/windows/detector_windows.go b/internal/wechat/process/windows/detector_windows.go index 80bf9a3..f5cc1a3 100644 --- a/internal/wechat/process/windows/detector_windows.go +++ b/internal/wechat/process/windows/detector_windows.go @@ -4,8 +4,8 @@ import ( "path/filepath" "strings" + "github.com/rs/zerolog/log" "github.com/shirou/gopsutil/v4/process" - log "github.com/sirupsen/logrus" "github.com/sjzar/chatlog/internal/wechat/model" ) @@ -14,7 +14,7 @@ import ( func initializeProcessInfo(p *process.Process, info *model.Process) error { files, err := p.OpenFiles() if err != nil { - log.Error("获取打开文件列表失败: ", err) + log.Err(err).Msgf("获取进程 %d 的打开文件失败", p.Pid) return err } @@ -28,7 +28,7 @@ func initializeProcessInfo(p *process.Process, info *model.Process) error { filePath := f.Path[4:] // 移除 "\\?\" 前缀 parts := strings.Split(filePath, string(filepath.Separator)) if len(parts) < 4 { - log.Debug("无效的文件路径格式: " + filePath) + log.Debug().Msg("无效的文件路径: " + filePath) continue } diff --git a/internal/wechat/wechat.go b/internal/wechat/wechat.go index 9b22333..140c6d0 100644 --- a/internal/wechat/wechat.go +++ b/internal/wechat/wechat.go @@ -2,9 +2,9 @@ package wechat import ( "context" - "fmt" "os" + "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/wechat/decrypt" "github.com/sjzar/chatlog/internal/wechat/key" "github.com/sjzar/chatlog/internal/wechat/model" @@ -71,28 +71,28 @@ func (a *Account) GetKey(ctx context.Context) (string, error) { // 刷新进程状态 if err := a.RefreshStatus(); err != nil { - return "", fmt.Errorf("failed to refresh process status: %w", err) + return "", errors.RefreshProcessStatusFailed(err) } // 检查账号状态 if a.Status != model.StatusOnline { - return "", fmt.Errorf("account %s is not online", a.Name) + return "", errors.WeChatAccountNotOnline(a.Name) } // 创建密钥提取器 - 使用新的接口,传入平台和版本信息 extractor, err := key.NewExtractor(a.Platform, a.Version) if err != nil { - return "", fmt.Errorf("failed to create key extractor: %w", err) + return "", err } process, err := GetProcess(a.Name) if err != nil { - return "", fmt.Errorf("failed to get process: %w", err) + return "", err } validator, err := decrypt.NewValidator(process.DataDir, process.Platform, process.Version) if err != nil { - return "", fmt.Errorf("failed to create validator: %w", err) + return "", err } extractor.SetValidate(validator) diff --git a/internal/wechatdb/datasource/darwinv3/datasource.go b/internal/wechatdb/datasource/darwinv3/datasource.go index e7038a5..e754e39 100644 --- a/internal/wechatdb/datasource/darwinv3/datasource.go +++ b/internal/wechatdb/datasource/darwinv3/datasource.go @@ -9,11 +9,12 @@ import ( "strings" "time" + _ "github.com/mattn/go-sqlite3" + "github.com/rs/zerolog/log" + + "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/model" "github.com/sjzar/chatlog/pkg/util" - - _ "github.com/mattn/go-sqlite3" - log "github.com/sirupsen/logrus" ) const ( @@ -45,19 +46,19 @@ func New(path string) (*DataSource, error) { } if err := ds.initMessageDbs(path); err != nil { - return nil, fmt.Errorf("初始化消息数据库失败: %w", err) + return nil, errors.DBInitFailed(err) } if err := ds.initContactDb(path); err != nil { - return nil, fmt.Errorf("初始化联系人数据库失败: %w", err) + return nil, errors.DBInitFailed(err) } if err := ds.initChatRoomDb(path); err != nil { - return nil, fmt.Errorf("初始化群聊数据库失败: %w", err) + return nil, errors.DBInitFailed(err) } if err := ds.initSessionDb(path); err != nil { - return nil, fmt.Errorf("初始化会话数据库失败: %w", err) + return nil, errors.DBInitFailed(err) } if err := ds.initMediaDb(path); err != nil { - return nil, fmt.Errorf("初始化会话数据库失败: %w", err) + return nil, errors.DBInitFailed(err) } return ds, nil @@ -67,11 +68,11 @@ func (ds *DataSource) initMessageDbs(path string) error { files, err := util.FindFilesWithPatterns(path, MessageFilePattern, true) if err != nil { - return fmt.Errorf("查找消息数据库文件失败: %w", err) + return errors.DBFileNotFound(path, MessageFilePattern, err) } if len(files) == 0 { - return fmt.Errorf("未找到任何消息数据库文件: %s", path) + return errors.DBFileNotFound(path, MessageFilePattern, nil) } // 处理每个数据库文件 @@ -79,7 +80,7 @@ func (ds *DataSource) initMessageDbs(path string) error { // 连接数据库 db, err := sql.Open("sqlite3", filePath) if err != nil { - log.Printf("警告: 连接数据库 %s 失败: %v", filePath, err) + log.Err(err).Msgf("连接数据库 %s 失败", filePath) continue } ds.messageDbs = append(ds.messageDbs, db) @@ -87,14 +88,14 @@ func (ds *DataSource) initMessageDbs(path string) error { // 获取所有表名 rows, err := db.Query("SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'Chat_%'") if err != nil { - log.Printf("警告: 获取表名失败: %v", err) + log.Err(err).Msgf("数据库 %s 中没有 Chat 表", filePath) continue } for rows.Next() { var tableName string if err := rows.Scan(&tableName); err != nil { - log.Printf("警告: 扫描表名失败: %v", err) + log.Err(err).Msgf("数据库 %s 扫描表名失败", filePath) continue } @@ -115,16 +116,16 @@ func (ds *DataSource) initContactDb(path string) error { files, err := util.FindFilesWithPatterns(path, ContactFilePattern, true) if err != nil { - return fmt.Errorf("查找联系人数据库文件失败: %w", err) + return errors.DBFileNotFound(path, ContactFilePattern, err) } if len(files) == 0 { - return fmt.Errorf("未找到联系人数据库文件: %s", path) + return errors.DBFileNotFound(path, ContactFilePattern, nil) } ds.contactDb, err = sql.Open("sqlite3", files[0]) if err != nil { - return fmt.Errorf("连接联系人数据库失败: %w", err) + return errors.DBConnectFailed(files[0], err) } return nil @@ -133,19 +134,19 @@ func (ds *DataSource) initContactDb(path string) error { func (ds *DataSource) initChatRoomDb(path string) error { files, err := util.FindFilesWithPatterns(path, ChatRoomFilePattern, true) if err != nil { - return fmt.Errorf("查找群聊数据库文件失败: %w", err) + return errors.DBFileNotFound(path, ChatRoomFilePattern, err) } if len(files) == 0 { - return fmt.Errorf("未找到群聊数据库文件: %s", path) + return errors.DBFileNotFound(path, ChatRoomFilePattern, nil) } ds.chatRoomDb, err = sql.Open("sqlite3", files[0]) if err != nil { - return fmt.Errorf("连接群聊数据库失败: %w", err) + return errors.DBConnectFailed(files[0], err) } rows, err := ds.chatRoomDb.Query("SELECT m_nsUsrName, IFNULL(nickname,\"\") FROM GroupMember") if err != nil { - log.Printf("警告: 获取群聊成员失败: %v", err) + log.Err(err).Msgf("数据库 %s 获取群聊成员失败", files[0]) return nil } @@ -153,7 +154,7 @@ func (ds *DataSource) initChatRoomDb(path string) error { var user string var nickName string if err := rows.Scan(&user, &nickName); err != nil { - log.Printf("警告: 扫描表名失败: %v", err) + log.Err(err).Msgf("数据库 %s 扫描表名失败", files[0]) continue } ds.user2DisplayName[user] = nickName @@ -166,14 +167,14 @@ func (ds *DataSource) initChatRoomDb(path string) error { func (ds *DataSource) initSessionDb(path string) error { files, err := util.FindFilesWithPatterns(path, SessionFilePattern, true) if err != nil { - return fmt.Errorf("查找最近会话数据库文件失败: %w", err) + return errors.DBFileNotFound(path, SessionFilePattern, err) } if len(files) == 0 { - return fmt.Errorf("未找到最近会话数据库文件: %s", path) + return errors.DBFileNotFound(path, SessionFilePattern, nil) } ds.sessionDb, err = sql.Open("sqlite3", files[0]) if err != nil { - return fmt.Errorf("连接最近会话数据库失败: %w", err) + return errors.DBConnectFailed(files[0], err) } return nil } @@ -181,14 +182,14 @@ func (ds *DataSource) initSessionDb(path string) error { func (ds *DataSource) initMediaDb(path string) error { files, err := util.FindFilesWithPatterns(path, MediaFilePattern, true) if err != nil { - return fmt.Errorf("查找媒体数据库文件失败: %w", err) + return errors.DBFileNotFound(path, MediaFilePattern, err) } if len(files) == 0 { - return fmt.Errorf("未找到媒体数据库文件: %s", path) + return errors.DBFileNotFound(path, MediaFilePattern, nil) } ds.mediaDb, err = sql.Open("sqlite3", files[0]) if err != nil { - return fmt.Errorf("连接媒体数据库失败: %w", err) + return errors.DBConnectFailed(files[0], err) } return nil } @@ -198,14 +199,14 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T // 在 darwinv3 中,每个联系人/群聊的消息存储在单独的表中,表名为 Chat_md5(talker) // 首先需要找到对应的表名 if talker == "" { - return nil, fmt.Errorf("talker 不能为空") + return nil, errors.ErrTalkerEmpty } _talkerMd5Bytes := md5.Sum([]byte(talker)) talkerMd5 := hex.EncodeToString(_talkerMd5Bytes[:]) db, ok := ds.talkerDBMap[talkerMd5] if !ok { - return nil, fmt.Errorf("未找到 talker %s 的消息数据库", talker) + return nil, errors.TalkerNotFound(talker) } tableName := fmt.Sprintf("Chat_%s", talkerMd5) @@ -228,7 +229,7 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T // 执行查询 rows, err := db.QueryContext(ctx, query, startTime.Unix(), endTime.Unix()) if err != nil { - return nil, fmt.Errorf("查询表 %s 失败: %w", tableName, err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -243,7 +244,7 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T &msg.MesDes, ) if err != nil { - log.Printf("警告: 扫描消息行失败: %v", err) + log.Err(err).Msgf("扫描消息行失败") continue } @@ -298,7 +299,7 @@ func (ds *DataSource) GetContacts(ctx context.Context, key string, limit, offset // 执行查询 rows, err := ds.contactDb.QueryContext(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("查询联系人失败: %w", err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -314,7 +315,7 @@ func (ds *DataSource) GetContacts(ctx context.Context, key string, limit, offset ) if err != nil { - return nil, fmt.Errorf("扫描联系人行失败: %w", err) + return nil, errors.ScanRowFailed(err) } contacts = append(contacts, contactDarwinV3.Wrap()) @@ -352,7 +353,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse // 执行查询 rows, err := ds.chatRoomDb.QueryContext(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("查询群聊失败: %w", err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -368,7 +369,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse ) if err != nil { - return nil, fmt.Errorf("扫描群聊行失败: %w", err) + return nil, errors.ScanRowFailed(err) } chatRooms = append(chatRooms, chatRoomDarwinV3.Wrap(ds.user2DisplayName)) @@ -386,7 +387,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse contacts[0].UserName) if err != nil { - return nil, fmt.Errorf("查询群聊失败: %w", err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -401,7 +402,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse ) if err != nil { - return nil, fmt.Errorf("扫描群聊行失败: %w", err) + return nil, errors.ScanRowFailed(err) } chatRooms = append(chatRooms, chatRoomDarwinV3.Wrap(ds.user2DisplayName)) @@ -449,7 +450,7 @@ func (ds *DataSource) GetSessions(ctx context.Context, key string, limit, offset // 执行查询 rows, err := ds.sessionDb.QueryContext(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("查询会话失败: %w", err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -462,7 +463,7 @@ func (ds *DataSource) GetSessions(ctx context.Context, key string, limit, offset ) if err != nil { - return nil, fmt.Errorf("扫描会话行失败: %w", err) + return nil, errors.ScanRowFailed(err) } // 包装成通用模型 @@ -488,7 +489,7 @@ func (ds *DataSource) GetSessions(ctx context.Context, key string, limit, offset func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (*model.Media, error) { if key == "" { - return nil, fmt.Errorf("key 不能为空") + return nil, errors.ErrKeyEmpty } query := `SELECT r.mediaMd5, @@ -507,7 +508,7 @@ WHERE // 执行查询 rows, err := ds.mediaDb.QueryContext(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("查询媒体失败: %w", err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -524,7 +525,7 @@ WHERE ) if err != nil { - return nil, fmt.Errorf("扫描会话行失败: %w", err) + return nil, errors.ScanRowFailed(err) } // 包装成通用模型 @@ -532,7 +533,7 @@ WHERE } if media == nil { - return nil, fmt.Errorf("未找到媒体 %s", key) + return nil, errors.ErrMediaNotFound } return media, nil @@ -543,42 +544,42 @@ func (ds *DataSource) Close() error { var errs []error // 关闭消息数据库连接 - for i, db := range ds.messageDbs { + for _, db := range ds.messageDbs { if err := db.Close(); err != nil { - errs = append(errs, fmt.Errorf("关闭消息数据库 %d 失败: %w", i, err)) + errs = append(errs, err) } } // 关闭联系人数据库连接 if ds.contactDb != nil { if err := ds.contactDb.Close(); err != nil { - errs = append(errs, fmt.Errorf("关闭联系人数据库失败: %w", err)) + errs = append(errs, err) } } // 关闭群聊数据库连接 if ds.chatRoomDb != nil { if err := ds.chatRoomDb.Close(); err != nil { - errs = append(errs, fmt.Errorf("关闭群聊数据库失败: %w", err)) + errs = append(errs, err) } } // 关闭会话数据库连接 if ds.sessionDb != nil { if err := ds.sessionDb.Close(); err != nil { - errs = append(errs, fmt.Errorf("关闭会话数据库失败: %w", err)) + errs = append(errs, err) } } // 关闭媒体数据库连接 if ds.mediaDb != nil { if err := ds.mediaDb.Close(); err != nil { - errs = append(errs, fmt.Errorf("关闭媒体数据库失败: %w", err)) + errs = append(errs, err) } } if len(errs) > 0 { - return fmt.Errorf("关闭数据库连接时发生错误: %v", errs) + return errors.DBCloseFailed(errs[0]) } return nil diff --git a/internal/wechatdb/datasource/datasource.go b/internal/wechatdb/datasource/datasource.go index 727969b..327362a 100644 --- a/internal/wechatdb/datasource/datasource.go +++ b/internal/wechatdb/datasource/datasource.go @@ -2,20 +2,15 @@ package datasource import ( "context" - "fmt" "time" + "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/model" "github.com/sjzar/chatlog/internal/wechatdb/datasource/darwinv3" v4 "github.com/sjzar/chatlog/internal/wechatdb/datasource/v4" "github.com/sjzar/chatlog/internal/wechatdb/datasource/windowsv3" ) -// 错误定义 -var ( - ErrUnsupportedPlatform = fmt.Errorf("unsupported platform") -) - type DataSource interface { // 消息 @@ -36,7 +31,7 @@ type DataSource interface { Close() error } -func NewDataSource(path string, platform string, version int) (DataSource, error) { +func New(path string, platform string, version int) (DataSource, error) { switch { case platform == "windows" && version == 3: return windowsv3.New(path) @@ -47,6 +42,6 @@ func NewDataSource(path string, platform string, version int) (DataSource, error case platform == "darwin" && version == 4: return v4.New(path) default: - return nil, fmt.Errorf("%w: %s v%d", ErrUnsupportedPlatform, platform, version) + return nil, errors.PlatformUnsupported(platform, version) } } diff --git a/internal/wechatdb/datasource/v4/datasource.go b/internal/wechatdb/datasource/v4/datasource.go index eb205ea..dfd2bb3 100644 --- a/internal/wechatdb/datasource/v4/datasource.go +++ b/internal/wechatdb/datasource/v4/datasource.go @@ -6,15 +6,16 @@ import ( "database/sql" "encoding/hex" "fmt" - "log" "sort" "strings" "time" + _ "github.com/mattn/go-sqlite3" + "github.com/rs/zerolog/log" + + "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/model" "github.com/sjzar/chatlog/pkg/util" - - _ "github.com/mattn/go-sqlite3" ) const ( @@ -51,16 +52,16 @@ func New(path string) (*DataSource, error) { } if err := ds.initMessageDbs(path); err != nil { - return nil, fmt.Errorf("初始化消息数据库失败: %w", err) + return nil, errors.DBInitFailed(err) } if err := ds.initContactDb(path); err != nil { - return nil, fmt.Errorf("初始化联系人数据库失败: %w", err) + return nil, errors.DBInitFailed(err) } if err := ds.initSessionDb(path); err != nil { - return nil, fmt.Errorf("初始化会话数据库失败: %w", err) + return nil, errors.DBInitFailed(err) } if err := ds.initMediaDb(path); err != nil { - return nil, fmt.Errorf("初始化媒体数据库失败: %w", err) + return nil, errors.DBInitFailed(err) } return ds, nil @@ -70,11 +71,11 @@ func (ds *DataSource) initMessageDbs(path string) error { // 查找所有消息数据库文件 files, err := util.FindFilesWithPatterns(path, MessageFilePattern, true) if err != nil { - return fmt.Errorf("查找消息数据库文件失败: %w", err) + return errors.DBFileNotFound(path, MessageFilePattern, err) } if len(files) == 0 { - return fmt.Errorf("未找到任何消息数据库文件: %s", path) + return errors.DBFileNotFound(path, MessageFilePattern, nil) } // 处理每个数据库文件 @@ -82,7 +83,7 @@ func (ds *DataSource) initMessageDbs(path string) error { // 连接数据库 db, err := sql.Open("sqlite3", filePath) if err != nil { - log.Printf("警告: 连接数据库 %s 失败: %v", filePath, err) + log.Err(err).Msgf("连接数据库 %s 失败", filePath) continue } @@ -92,7 +93,7 @@ func (ds *DataSource) initMessageDbs(path string) error { row := db.QueryRow("SELECT timestamp FROM Timestamp LIMIT 1") if err := row.Scan(×tamp); err != nil { - log.Printf("警告: 获取数据库 %s 的时间戳失败: %v", filePath, err) + log.Err(err).Msgf("获取数据库 %s 的时间戳失败", filePath) db.Close() continue } @@ -102,7 +103,7 @@ func (ds *DataSource) initMessageDbs(path string) error { id2Name := make(map[int]string) rows, err := db.Query("SELECT user_name FROM Name2Id") if err != nil { - log.Printf("警告: 获取数据库 %s 的 Name2Id 表失败: %v", filePath, err) + log.Err(err).Msgf("获取数据库 %s 的 Name2Id 表失败", filePath) db.Close() continue } @@ -111,7 +112,7 @@ func (ds *DataSource) initMessageDbs(path string) error { for rows.Next() { var name string if err := rows.Scan(&name); err != nil { - log.Printf("警告: 扫描 Name2Id 行失败: %v", err) + log.Err(err).Msgf("数据库 %s 扫描 Name2Id 行失败", filePath) continue } id2Name[i] = name @@ -150,16 +151,16 @@ func (ds *DataSource) initMessageDbs(path string) error { func (ds *DataSource) initContactDb(path string) error { files, err := util.FindFilesWithPatterns(path, ContactFilePattern, true) if err != nil { - return fmt.Errorf("查找联系人数据库文件失败: %w", err) + return errors.DBFileNotFound(path, ContactFilePattern, err) } if len(files) == 0 { - return fmt.Errorf("未找到联系人数据库文件: %s", path) + return errors.DBFileNotFound(path, ContactFilePattern, nil) } ds.contactDb, err = sql.Open("sqlite3", files[0]) if err != nil { - return fmt.Errorf("连接联系人数据库失败: %w", err) + return errors.DBConnectFailed(files[0], err) } return nil @@ -168,14 +169,14 @@ func (ds *DataSource) initContactDb(path string) error { func (ds *DataSource) initSessionDb(path string) error { files, err := util.FindFilesWithPatterns(path, SessionFilePattern, true) if err != nil { - return fmt.Errorf("查找最近会话数据库文件失败: %w", err) + return errors.DBFileNotFound(path, SessionFilePattern, err) } if len(files) == 0 { - return fmt.Errorf("未找到最近会话数据库文件: %s", path) + return errors.DBFileNotFound(path, SessionFilePattern, nil) } ds.sessionDb, err = sql.Open("sqlite3", files[0]) if err != nil { - return fmt.Errorf("连接最近会话数据库失败: %w", err) + return errors.DBConnectFailed(files[0], err) } return nil } @@ -183,14 +184,14 @@ func (ds *DataSource) initSessionDb(path string) error { func (ds *DataSource) initMediaDb(path string) error { files, err := util.FindFilesWithPatterns(path, MediaFilePattern, true) if err != nil { - return fmt.Errorf("查找媒体数据库文件失败: %w", err) + return errors.DBFileNotFound(path, MediaFilePattern, err) } if len(files) == 0 { - return fmt.Errorf("未找到媒体数据库文件: %s", path) + return errors.DBFileNotFound(path, MediaFilePattern, nil) } ds.mediaDb, err = sql.Open("sqlite3", files[0]) if err != nil { - return fmt.Errorf("连接媒体数据库失败: %w", err) + return errors.DBConnectFailed(files[0], err) } return nil } @@ -208,13 +209,13 @@ func (ds *DataSource) getDBInfosForTimeRange(startTime, endTime time.Time) []Mes func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) { if talker == "" { - return nil, fmt.Errorf("必须指定 talker 参数") + return nil, errors.ErrTalkerEmpty } // 找到时间范围内的数据库文件 dbInfos := ds.getDBInfosForTimeRange(startTime, endTime) if len(dbInfos) == 0 { - return nil, fmt.Errorf("未找到时间范围 %v 到 %v 内的数据库文件", startTime, endTime) + return nil, errors.TimeRangeNotFound(startTime, endTime) } if len(dbInfos) == 1 { @@ -233,13 +234,13 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T db, ok := ds.messageDbs[dbInfo.FilePath] if !ok { - log.Printf("警告: 数据库 %s 未打开", dbInfo.FilePath) + log.Error().Msgf("数据库 %s 未打开", dbInfo.FilePath) continue } messages, err := ds.getMessagesFromDB(ctx, db, dbInfo, startTime, endTime, talker) if err != nil { - log.Printf("警告: 从数据库 %s 获取消息失败: %v", dbInfo.FilePath, err) + log.Err(err).Msgf("从数据库 %s 获取消息失败", dbInfo.FilePath) continue } @@ -274,7 +275,7 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageDBInfo, startTime, endTime time.Time, talker string, limit, offset int) ([]*model.Message, error) { db, ok := ds.messageDbs[dbInfo.FilePath] if !ok { - return nil, fmt.Errorf("数据库 %s 未打开", dbInfo.FilePath) + return nil, errors.DBConnectFailed(dbInfo.FilePath, nil) } // 构建表名 @@ -303,7 +304,7 @@ func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageD // 执行查询 rows, err := db.QueryContext(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("查询数据库 %s 失败: %w", dbInfo.FilePath, err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -323,7 +324,7 @@ func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageD &msg.Status, ) if err != nil { - return nil, fmt.Errorf("扫描消息行失败: %w", err) + return nil, errors.ScanRowFailed(err) } messages = append(messages, msg.Wrap(dbInfo.ID2Name, isChatRoom)) @@ -350,7 +351,7 @@ func (ds *DataSource) getMessagesFromDB(ctx context.Context, db *sql.DB, dbInfo // 表不存在,返回空结果 return []*model.Message{}, nil } - return nil, fmt.Errorf("检查表 %s 是否存在失败: %w", tableName, err) + return nil, errors.QueryFailed("", err) } // 构建查询条件 @@ -371,7 +372,7 @@ func (ds *DataSource) getMessagesFromDB(ctx context.Context, db *sql.DB, dbInfo if strings.Contains(err.Error(), "no such table") { return []*model.Message{}, nil } - return nil, fmt.Errorf("查询数据库失败: %w", err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -391,7 +392,7 @@ func (ds *DataSource) getMessagesFromDB(ctx context.Context, db *sql.DB, dbInfo &msg.Status, ) if err != nil { - return nil, fmt.Errorf("扫描消息行失败: %w", err) + return nil, errors.ScanRowFailed(err) } messages = append(messages, msg.Wrap(dbInfo.ID2Name, isChatRoom)) @@ -428,7 +429,7 @@ func (ds *DataSource) GetContacts(ctx context.Context, key string, limit, offset // 执行查询 rows, err := ds.contactDb.QueryContext(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("查询联系人失败: %w", err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -444,7 +445,7 @@ func (ds *DataSource) GetContacts(ctx context.Context, key string, limit, offset ) if err != nil { - return nil, fmt.Errorf("扫描联系人行失败: %w", err) + return nil, errors.ScanRowFailed(err) } contacts = append(contacts, contactV4.Wrap()) @@ -466,7 +467,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse // 执行查询 rows, err := ds.contactDb.QueryContext(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("查询群聊失败: %w", err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -480,7 +481,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse ) if err != nil { - return nil, fmt.Errorf("扫描群聊行失败: %w", err) + return nil, errors.ScanRowFailed(err) } chatRooms = append(chatRooms, chatRoomV4.Wrap()) @@ -496,7 +497,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse contacts[0].UserName) if err != nil { - return nil, fmt.Errorf("查询群聊失败: %w", err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -509,7 +510,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse ) if err != nil { - return nil, fmt.Errorf("扫描群聊行失败: %w", err) + return nil, errors.ScanRowFailed(err) } chatRooms = append(chatRooms, chatRoomV4.Wrap()) @@ -543,7 +544,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse // 执行查询 rows, err := ds.contactDb.QueryContext(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("查询群聊失败: %w", err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -557,7 +558,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse ) if err != nil { - return nil, fmt.Errorf("扫描群聊行失败: %w", err) + return nil, errors.ScanRowFailed(err) } chatRooms = append(chatRooms, chatRoomV4.Wrap()) @@ -597,7 +598,7 @@ func (ds *DataSource) GetSessions(ctx context.Context, key string, limit, offset // 执行查询 rows, err := ds.sessionDb.QueryContext(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("查询会话失败: %w", err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -613,7 +614,7 @@ func (ds *DataSource) GetSessions(ctx context.Context, key string, limit, offset ) if err != nil { - return nil, fmt.Errorf("扫描会话行失败: %w", err) + return nil, errors.ScanRowFailed(err) } sessions = append(sessions, sessionV4.Wrap()) @@ -624,11 +625,11 @@ func (ds *DataSource) GetSessions(ctx context.Context, key string, limit, offset func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (*model.Media, error) { if key == "" { - return nil, fmt.Errorf("key 不能为空") + return nil, errors.ErrKeyEmpty } if len(key) != 32 { - return nil, fmt.Errorf("key 长度必须为 32") + return nil, errors.ErrKeyLengthMust32 } var table string @@ -640,7 +641,7 @@ func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (* case "file": table = "file_hardlink_info_v3" default: - return nil, fmt.Errorf("不支持的媒体类型: %s", _type) + return nil, errors.MediaTypeUnsupported(_type) } query := fmt.Sprintf(` @@ -663,7 +664,7 @@ func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (* rows, err := ds.mediaDb.QueryContext(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("查询媒体失败: %w", err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -679,7 +680,7 @@ func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (* &mediaV4.Dir2, ) if err != nil { - return nil, fmt.Errorf("扫描会话行失败: %w", err) + return nil, errors.ScanRowFailed(err) } mediaV4.Type = _type media = mediaV4.Wrap() @@ -691,7 +692,7 @@ func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (* } if media == nil { - return nil, fmt.Errorf("未找到媒体 %s", key) + return nil, errors.ErrMediaNotFound } return media, nil @@ -701,34 +702,34 @@ func (ds *DataSource) Close() error { var errs []error // 关闭消息数据库连接 - for path, db := range ds.messageDbs { + for _, db := range ds.messageDbs { if err := db.Close(); err != nil { - errs = append(errs, fmt.Errorf("关闭消息数据库 %s 失败: %w", path, err)) + errs = append(errs, err) } } // 关闭联系人数据库连接 if ds.contactDb != nil { if err := ds.contactDb.Close(); err != nil { - errs = append(errs, fmt.Errorf("关闭联系人数据库失败: %w", err)) + errs = append(errs, err) } } // 关闭会话数据库连接 if ds.sessionDb != nil { if err := ds.sessionDb.Close(); err != nil { - errs = append(errs, fmt.Errorf("关闭会话数据库失败: %w", err)) + errs = append(errs, err) } } if ds.mediaDb != nil { if err := ds.mediaDb.Close(); err != nil { - errs = append(errs, fmt.Errorf("关闭媒体数据库失败: %w", err)) + errs = append(errs, err) } } if len(errs) > 0 { - return fmt.Errorf("关闭数据库连接时发生错误: %v", errs) + return errors.DBCloseFailed(errs[0]) } return nil diff --git a/internal/wechatdb/datasource/windowsv3/datasource.go b/internal/wechatdb/datasource/windowsv3/datasource.go index 20d857f..a74dd05 100644 --- a/internal/wechatdb/datasource/windowsv3/datasource.go +++ b/internal/wechatdb/datasource/windowsv3/datasource.go @@ -5,15 +5,16 @@ import ( "database/sql" "encoding/hex" "fmt" - "log" "sort" "strings" "time" + _ "github.com/mattn/go-sqlite3" + "github.com/rs/zerolog/log" + + "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/model" "github.com/sjzar/chatlog/pkg/util" - - _ "github.com/mattn/go-sqlite3" ) const ( @@ -56,16 +57,16 @@ func New(path string) (*DataSource, error) { // 初始化消息数据库 if err := ds.initMessageDbs(path); err != nil { - return nil, fmt.Errorf("初始化消息数据库失败: %w", err) + return nil, errors.DBInitFailed(err) } // 初始化联系人数据库 if err := ds.initContactDb(path); err != nil { - return nil, fmt.Errorf("初始化联系人数据库失败: %w", err) + return nil, errors.DBInitFailed(err) } if err := ds.initMediaDb(path); err != nil { - return nil, fmt.Errorf("初始化多媒体数据库失败: %w", err) + return nil, errors.DBInitFailed(err) } return ds, nil @@ -76,11 +77,11 @@ func (ds *DataSource) initMessageDbs(path string) error { // 查找所有消息数据库文件 files, err := util.FindFilesWithPatterns(path, MessageFilePattern, true) if err != nil { - return fmt.Errorf("查找消息数据库文件失败: %w", err) + return errors.DBFileNotFound(path, MessageFilePattern, err) } if len(files) == 0 { - return fmt.Errorf("未找到任何消息数据库文件: %s", path) + return errors.DBFileNotFound(path, MessageFilePattern, nil) } // 处理每个数据库文件 @@ -88,7 +89,7 @@ func (ds *DataSource) initMessageDbs(path string) error { // 连接数据库 db, err := sql.Open("sqlite3", filePath) if err != nil { - log.Printf("警告: 连接数据库 %s 失败: %v", filePath, err) + log.Err(err).Msgf("连接数据库 %s 失败", filePath) continue } @@ -97,7 +98,7 @@ func (ds *DataSource) initMessageDbs(path string) error { rows, err := db.Query("SELECT tableIndex, tableVersion, tableDesc FROM DBInfo") if err != nil { - log.Printf("警告: 查询数据库 %s 的 DBInfo 表失败: %v", filePath, err) + log.Err(err).Msgf("查询数据库 %s 的 DBInfo 表失败", filePath) db.Close() continue } @@ -108,7 +109,7 @@ func (ds *DataSource) initMessageDbs(path string) error { var tableDesc string if err := rows.Scan(&tableIndex, &tableVersion, &tableDesc); err != nil { - log.Printf("警告: 扫描 DBInfo 行失败: %v", err) + log.Err(err).Msg("扫描 DBInfo 行失败") continue } @@ -124,7 +125,7 @@ func (ds *DataSource) initMessageDbs(path string) error { talkerMap := make(map[string]int) rows, err = db.Query("SELECT UsrName FROM Name2ID") if err != nil { - log.Printf("警告: 查询数据库 %s 的 Name2ID 表失败: %v", filePath, err) + log.Err(err).Msgf("查询数据库 %s 的 Name2ID 表失败", filePath) db.Close() continue } @@ -133,7 +134,7 @@ func (ds *DataSource) initMessageDbs(path string) error { for rows.Next() { var userName string if err := rows.Scan(&userName); err != nil { - log.Printf("警告: 扫描 Name2ID 行失败: %v", err) + log.Err(err).Msg("扫描 Name2ID 行失败") continue } talkerMap[userName] = i @@ -173,18 +174,18 @@ func (ds *DataSource) initMessageDbs(path string) error { func (ds *DataSource) initContactDb(path string) error { files, err := util.FindFilesWithPatterns(path, ContactFilePattern, true) if err != nil { - return fmt.Errorf("查找联系人数据库文件失败: %w", err) + return errors.DBFileNotFound(path, ContactFilePattern, err) } if len(files) == 0 { - return fmt.Errorf("未找到联系人数据库文件: %s", path) + return errors.DBFileNotFound(path, ContactFilePattern, nil) } ds.contactDbFile = files[0] ds.contactDb, err = sql.Open("sqlite3", ds.contactDbFile) if err != nil { - return fmt.Errorf("连接联系人数据库失败: %w", err) + return errors.DBConnectFailed(ds.contactDbFile, err) } return nil @@ -194,44 +195,44 @@ func (ds *DataSource) initContactDb(path string) error { func (ds *DataSource) initMediaDb(path string) error { files, err := util.FindFilesWithPatterns(path, ImageFilePattern, true) if err != nil { - return fmt.Errorf("查找图片数据库文件失败: %w", err) + return errors.DBFileNotFound(path, ImageFilePattern, err) } if len(files) == 0 { - return fmt.Errorf("未找到图片数据库文件: %s", path) + return errors.DBFileNotFound(path, ImageFilePattern, nil) } ds.imageDb, err = sql.Open("sqlite3", files[0]) if err != nil { - return fmt.Errorf("连接图片数据库失败: %w", err) + return errors.DBConnectFailed(files[0], err) } files, err = util.FindFilesWithPatterns(path, VideoFilePattern, true) if err != nil { - return fmt.Errorf("查找视频数据库文件失败: %w", err) + return errors.DBFileNotFound(path, VideoFilePattern, err) } if len(files) == 0 { - return fmt.Errorf("未找到视频数据库文件: %s", path) + return errors.DBFileNotFound(path, VideoFilePattern, nil) } ds.videoDb, err = sql.Open("sqlite3", files[0]) if err != nil { - return fmt.Errorf("连接视频数据库失败: %w", err) + return errors.DBConnectFailed(files[0], err) } files, err = util.FindFilesWithPatterns(path, FileFilePattern, true) if err != nil { - return fmt.Errorf("查找文件数据库文件失败: %w", err) + return errors.DBFileNotFound(path, FileFilePattern, err) } if len(files) == 0 { - return fmt.Errorf("未找到文件数据库文件: %s", path) + return errors.DBFileNotFound(path, FileFilePattern, nil) } ds.fileDb, err = sql.Open("sqlite3", files[0]) if err != nil { - return fmt.Errorf("连接文件数据库失败: %w", err) + return errors.DBConnectFailed(files[0], err) } return nil @@ -253,7 +254,7 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T // 找到时间范围内的数据库文件 dbInfos := ds.getDBInfosForTimeRange(startTime, endTime) if len(dbInfos) == 0 { - return nil, fmt.Errorf("未找到时间范围 %v 到 %v 内的数据库文件", startTime, endTime) + return nil, errors.TimeRangeNotFound(startTime, endTime) } if len(dbInfos) == 1 { @@ -272,7 +273,7 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T db, ok := ds.messageDbs[dbInfo.FilePath] if !ok { - log.Printf("警告: 数据库 %s 未打开", dbInfo.FilePath) + log.Error().Msgf("数据库 %s 未打开", dbInfo.FilePath) continue } @@ -302,7 +303,7 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T // 执行查询 rows, err := db.QueryContext(ctx, query, args...) if err != nil { - log.Printf("警告: 查询数据库 %s 失败: %v", dbInfo.FilePath, err) + log.Err(err).Msgf("查询数据库 %s 失败", dbInfo.FilePath) continue } @@ -325,7 +326,7 @@ func (ds *DataSource) GetMessages(ctx context.Context, startTime, endTime time.T &bytesExtra, ) if err != nil { - log.Printf("警告: 扫描消息行失败: %v", err) + log.Err(err).Msg("扫描消息行失败") continue } msg.CompressContent = compressContent @@ -395,7 +396,7 @@ func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageD // 执行查询 rows, err := ds.messageDbs[dbInfo.FilePath].QueryContext(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("查询数据库 %s 失败: %w", dbInfo.FilePath, err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -418,7 +419,7 @@ func (ds *DataSource) getMessagesSingleFile(ctx context.Context, dbInfo MessageD &bytesExtra, ) if err != nil { - return nil, fmt.Errorf("扫描消息行失败: %w", err) + return nil, errors.ScanRowFailed(err) } msg.CompressContent = compressContent msg.BytesExtra = bytesExtra @@ -454,7 +455,7 @@ func (ds *DataSource) GetContacts(ctx context.Context, key string, limit, offset // 执行查询 rows, err := ds.contactDb.QueryContext(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("查询联系人失败: %w", err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -470,7 +471,7 @@ func (ds *DataSource) GetContacts(ctx context.Context, key string, limit, offset ) if err != nil { - return nil, fmt.Errorf("扫描联系人行失败: %w", err) + return nil, errors.ScanRowFailed(err) } contacts = append(contacts, contactV3.Wrap()) @@ -492,7 +493,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse // 执行查询 rows, err := ds.contactDb.QueryContext(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("查询群聊失败: %w", err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -506,7 +507,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse ) if err != nil { - return nil, fmt.Errorf("扫描群聊行失败: %w", err) + return nil, errors.ScanRowFailed(err) } chatRooms = append(chatRooms, chatRoomV3.Wrap()) @@ -522,7 +523,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse contacts[0].UserName) if err != nil { - return nil, fmt.Errorf("查询群聊失败: %w", err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -535,7 +536,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse ) if err != nil { - return nil, fmt.Errorf("扫描群聊行失败: %w", err) + return nil, errors.ScanRowFailed(err) } chatRooms = append(chatRooms, chatRoomV3.Wrap()) @@ -569,7 +570,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse // 执行查询 rows, err := ds.contactDb.QueryContext(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("查询群聊失败: %w", err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -583,7 +584,7 @@ func (ds *DataSource) GetChatRooms(ctx context.Context, key string, limit, offse ) if err != nil { - return nil, fmt.Errorf("扫描群聊行失败: %w", err) + return nil, errors.ScanRowFailed(err) } chatRooms = append(chatRooms, chatRoomV3.Wrap()) @@ -623,7 +624,7 @@ func (ds *DataSource) GetSessions(ctx context.Context, key string, limit, offset // 执行查询 rows, err := ds.contactDb.QueryContext(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("查询会话失败: %w", err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -639,7 +640,7 @@ func (ds *DataSource) GetSessions(ctx context.Context, key string, limit, offset ) if err != nil { - return nil, fmt.Errorf("扫描会话行失败: %w", err) + return nil, errors.ScanRowFailed(err) } sessions = append(sessions, sessionV3.Wrap()) @@ -650,12 +651,12 @@ func (ds *DataSource) GetSessions(ctx context.Context, key string, limit, offset func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (*model.Media, error) { if key == "" { - return nil, fmt.Errorf("key 不能为空") + return nil, errors.ErrKeyEmpty } md5key, err := hex.DecodeString(key) if err != nil { - return nil, fmt.Errorf("解析 key 失败: %w", err) + return nil, errors.DecodeKeyFailed(err) } var db *sql.DB @@ -675,7 +676,7 @@ func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (* table1 = "HardLinkFileAttribute" table2 = "HardLinkFileID" default: - return nil, fmt.Errorf("不支持的媒体类型: %s", _type) + return nil, errors.MediaTypeUnsupported(_type) } @@ -698,7 +699,7 @@ func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (* rows, err := db.QueryContext(ctx, query, args...) if err != nil { - return nil, fmt.Errorf("查询媒体失败: %w", err) + return nil, errors.QueryFailed(query, err) } defer rows.Close() @@ -712,7 +713,7 @@ func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (* &mediaV3.Dir2, ) if err != nil { - return nil, fmt.Errorf("扫描会话行失败: %w", err) + return nil, errors.ScanRowFailed(err) } mediaV3.Type = _type mediaV3.Key = key @@ -720,7 +721,7 @@ func (ds *DataSource) GetMedia(ctx context.Context, _type string, key string) (* } if media == nil { - return nil, fmt.Errorf("未找到媒体 %s", key) + return nil, errors.ErrMediaNotFound } return media, nil @@ -731,37 +732,37 @@ func (ds *DataSource) Close() error { var errs []error // 关闭消息数据库连接 - for path, db := range ds.messageDbs { + for _, db := range ds.messageDbs { if err := db.Close(); err != nil { - errs = append(errs, fmt.Errorf("关闭消息数据库 %s 失败: %w", path, err)) + errs = append(errs, err) } } // 关闭联系人数据库连接 if ds.contactDb != nil { if err := ds.contactDb.Close(); err != nil { - errs = append(errs, fmt.Errorf("关闭联系人数据库失败: %w", err)) + errs = append(errs, err) } } if ds.imageDb != nil { if err := ds.imageDb.Close(); err != nil { - errs = append(errs, fmt.Errorf("关闭图片数据库失败: %w", err)) + errs = append(errs, err) } } if ds.videoDb != nil { if err := ds.videoDb.Close(); err != nil { - errs = append(errs, fmt.Errorf("关闭视频数据库失败: %w", err)) + errs = append(errs, err) } } if ds.fileDb != nil { if err := ds.fileDb.Close(); err != nil { - errs = append(errs, fmt.Errorf("关闭文件数据库失败: %w", err)) + errs = append(errs, err) } } if len(errs) > 0 { - return fmt.Errorf("关闭数据库连接时发生错误: %v", errs) + return errors.DBCloseFailed(errs[0]) } return nil diff --git a/internal/wechatdb/repository/chatroom.go b/internal/wechatdb/repository/chatroom.go index fd2e399..04849f3 100644 --- a/internal/wechatdb/repository/chatroom.go +++ b/internal/wechatdb/repository/chatroom.go @@ -2,10 +2,10 @@ package repository import ( "context" - "fmt" "sort" "strings" + "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/model" ) @@ -14,7 +14,7 @@ func (r *Repository) initChatRoomCache(ctx context.Context) error { // 加载所有群聊到缓存 chatRooms, err := r.ds.GetChatRooms(ctx, "", 0, 0) if err != nil { - return fmt.Errorf("加载群聊失败: %w", err) + return err } chatRoomMap := make(map[string]*model.ChatRoom) @@ -75,7 +75,7 @@ func (r *Repository) GetChatRooms(ctx context.Context, key string, limit, offset if key != "" { ret = r.findChatRooms(key) if len(ret) == 0 { - return nil, fmt.Errorf("未找到群聊: %s", key) + return nil, errors.ChatRoomNotFound(key) } if limit > 0 { @@ -111,7 +111,7 @@ func (r *Repository) GetChatRooms(ctx context.Context, key string, limit, offset func (r *Repository) GetChatRoom(ctx context.Context, key string) (*model.ChatRoom, error) { chatRoom := r.findChatRoom(key) if chatRoom == nil { - return nil, fmt.Errorf("未找到群聊: %s", key) + return nil, errors.ChatRoomNotFound(key) } return chatRoom, nil } diff --git a/internal/wechatdb/repository/contact.go b/internal/wechatdb/repository/contact.go index ca40c8d..5d641fa 100644 --- a/internal/wechatdb/repository/contact.go +++ b/internal/wechatdb/repository/contact.go @@ -2,10 +2,10 @@ package repository import ( "context" - "fmt" "sort" "strings" + "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/model" ) @@ -14,7 +14,7 @@ func (r *Repository) initContactCache(ctx context.Context) error { // 加载所有联系人到缓存 contacts, err := r.ds.GetContacts(ctx, "", 0, 0) if err != nil { - return fmt.Errorf("加载联系人失败: %w", err) + return err } contactMap := make(map[string]*model.Contact) @@ -78,7 +78,7 @@ func (r *Repository) GetContact(ctx context.Context, key string) (*model.Contact // 先尝试从缓存中获取 contact := r.findContact(key) if contact == nil { - return nil, fmt.Errorf("未找到联系人: %s", key) + return nil, errors.ContactNotFound(key) } return contact, nil } @@ -88,7 +88,7 @@ func (r *Repository) GetContacts(ctx context.Context, key string, limit, offset if key != "" { ret = r.findContacts(key) if len(ret) == 0 { - return nil, fmt.Errorf("未找到联系人: %s", key) + return nil, errors.ContactNotFound(key) } if limit > 0 { end := offset + limit diff --git a/internal/wechatdb/repository/message.go b/internal/wechatdb/repository/message.go index 597f809..28e60d3 100644 --- a/internal/wechatdb/repository/message.go +++ b/internal/wechatdb/repository/message.go @@ -6,7 +6,7 @@ import ( "github.com/sjzar/chatlog/internal/model" - log "github.com/sirupsen/logrus" + "github.com/rs/zerolog/log" ) // GetMessages 实现 Repository 接口的 GetMessages 方法 @@ -25,7 +25,7 @@ func (r *Repository) GetMessages(ctx context.Context, startTime, endTime time.Ti // 补充消息信息 if err := r.EnrichMessages(ctx, messages); err != nil { - log.Debugf("EnrichMessages failed: %v", err) + log.Debug().Msgf("EnrichMessages failed: %v", err) } return messages, nil diff --git a/internal/wechatdb/repository/repository.go b/internal/wechatdb/repository/repository.go index aefe0d5..6fa8ce3 100644 --- a/internal/wechatdb/repository/repository.go +++ b/internal/wechatdb/repository/repository.go @@ -2,8 +2,8 @@ package repository import ( "context" - "fmt" + "github.com/sjzar/chatlog/internal/errors" "github.com/sjzar/chatlog/internal/model" "github.com/sjzar/chatlog/internal/wechatdb/datasource" ) @@ -58,7 +58,7 @@ func New(ds datasource.DataSource) (*Repository, error) { // 初始化缓存 if err := r.initCache(context.Background()); err != nil { - return nil, fmt.Errorf("初始化缓存失败: %w", err) + return nil, errors.InitCacheFailed(err) } return r, nil diff --git a/internal/wechatdb/wechatdb.go b/internal/wechatdb/wechatdb.go index 293a904..3c80cb7 100644 --- a/internal/wechatdb/wechatdb.go +++ b/internal/wechatdb/wechatdb.go @@ -2,7 +2,6 @@ package wechatdb import ( "context" - "fmt" "time" "github.com/sjzar/chatlog/internal/model" @@ -45,14 +44,14 @@ func (w *DB) Close() error { func (w *DB) Initialize() error { var err error - w.ds, err = datasource.NewDataSource(w.path, w.platform, w.version) + w.ds, err = datasource.New(w.path, w.platform, w.version) if err != nil { - return fmt.Errorf("初始化数据源失败: %w", err) + return err } w.repo, err = repository.New(w.ds) if err != nil { - return fmt.Errorf("初始化仓库失败: %w", err) + return err } return nil @@ -64,7 +63,7 @@ func (w *DB) GetMessages(start, end time.Time, talker string, limit, offset int) // 使用 repository 获取消息 messages, err := w.repo.GetMessages(ctx, start, end, talker, limit, offset) if err != nil { - return nil, fmt.Errorf("获取消息失败: %w", err) + return nil, err } return messages, nil @@ -114,7 +113,7 @@ func (w *DB) GetSessions(key string, limit, offset int) (*GetSessionsResp, error // 使用 repository 获取会话列表 sessions, err := w.repo.GetSessions(ctx, key, limit, offset) if err != nil { - return nil, fmt.Errorf("获取会话列表失败: %w", err) + return nil, err } return &GetSessionsResp{ diff --git a/pkg/config/config.go b/pkg/config/config.go index 571849f..489ba5b 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -20,7 +20,7 @@ import ( "errors" "os" - log "github.com/sirupsen/logrus" + "github.com/rs/zerolog/log" "github.com/spf13/viper" ) @@ -141,7 +141,7 @@ func PrepareDir(path string) error { return err } } else if !stat.IsDir() { - log.Debugf("%s is not a directory", path) + log.Debug().Msgf("%s is not a directory", path) return ErrInvalidDirectory } return nil diff --git a/pkg/util/os.go b/pkg/util/os.go index ae4ec4c..7cb25e9 100644 --- a/pkg/util/os.go +++ b/pkg/util/os.go @@ -8,7 +8,7 @@ import ( "regexp" "runtime" - log "github.com/sirupsen/logrus" + "github.com/rs/zerolog/log" ) // FindFilesWithPatterns 在指定目录下查找匹配多个正则表达式的文件 @@ -128,7 +128,7 @@ func PrepareDir(path string) error { return err } } else if !stat.IsDir() { - log.Debugf("%s is not a directory", path) + log.Debug().Msgf("%s is not a directory", path) return fmt.Errorf("%s is not a directory", path) } return nil