Skip to content

Commit

Permalink
Bug Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Jinnrry committed Sep 9, 2023
1 parent 9c3dccd commit e6ddfba
Show file tree
Hide file tree
Showing 12 changed files with 130 additions and 85 deletions.
Empty file modified build.sh
100644 → 100755
Empty file.
2 changes: 1 addition & 1 deletion server/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type Config struct {
//go:embed tables/*
var tableConfig embed.FS

const Version = "2.2.1"
const Version = "2.2.2"

const DBTypeMySQL = "mysql"
const DBTypeSQLite = "sqlite"
Expand Down
8 changes: 4 additions & 4 deletions server/http_server/http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ func HttpStart() {
httpServer = &http.Server{
Addr: fmt.Sprintf(":%d", HttpPort),
Handler: mux,
ReadTimeout: time.Second * 60,
WriteTimeout: time.Second * 60,
ReadTimeout: time.Second * 90,
WriteTimeout: time.Second * 90,
}
} else {
fe, err := fs.Sub(local, "dist")
Expand Down Expand Up @@ -64,8 +64,8 @@ func HttpStart() {
httpServer = &http.Server{
Addr: fmt.Sprintf(":%d", HttpPort),
Handler: session.Instance.LoadAndSave(mux),
ReadTimeout: time.Second * 60,
WriteTimeout: time.Second * 60,
ReadTimeout: time.Second * 90,
WriteTimeout: time.Second * 90,
}
}

Expand Down
33 changes: 4 additions & 29 deletions server/http_server/https_server.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
package http_server

import (
"bytes"
"embed"
"encoding/hex"
"encoding/json"
"fmt"
log "github.com/sirupsen/logrus"
"github.com/spf13/cast"
"io/fs"
olog "log"
"math/rand"
"net"
"net/http"
"os"
"pmail/config"
"pmail/controllers"
"pmail/controllers/email"
Expand All @@ -22,6 +17,7 @@ import (
"pmail/models"
"pmail/session"
"pmail/utils/context"
"pmail/utils/id"
"time"
)

Expand Down Expand Up @@ -80,8 +76,8 @@ func HttpsStart() {
httpsServer = &http.Server{
Addr: fmt.Sprintf(":%d", HttpsPort),
Handler: session.Instance.LoadAndSave(mux),
ReadTimeout: time.Second * 60,
WriteTimeout: time.Second * 60,
ReadTimeout: time.Second * 90,
WriteTimeout: time.Second * 90,
ErrorLog: nullLog,
}
err = httpsServer.ListenAndServeTLS("config/ssl/public.crt", "config/ssl/private.key")
Expand All @@ -97,27 +93,6 @@ func HttpsStop() {
}
}

func genLogID() string {
r := rand.New(rand.NewSource(time.Now().UnixMicro()))
if ip == "" {
ip = getLocalIP()
}
now := time.Now()
timestamp := uint32(now.Unix())
timeNano := now.UnixNano()
pid := os.Getpid()
b := bytes.Buffer{}

b.WriteString(hex.EncodeToString(net.ParseIP(ip).To4()))
b.WriteString(fmt.Sprintf("%x", timestamp&0xffffffff))
b.WriteString(fmt.Sprintf("%04x", timeNano&0xffff))
b.WriteString(fmt.Sprintf("%04x", pid&0xffff))
b.WriteString(fmt.Sprintf("%06x", r.Int31n(1<<24)))
b.WriteString("b0")

return b.String()
}

// 注入context
func contextIterceptor(h controllers.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
Expand All @@ -127,7 +102,7 @@ func contextIterceptor(h controllers.HandlerFunc) http.HandlerFunc {

ctx := &context.Context{}
ctx.Context = r.Context()
ctx.SetValue(context.LogID, genLogID())
ctx.SetValue(context.LogID, id.GenLogID())
lang := r.Header.Get("Lang")
if lang == "" {
lang = "en"
Expand Down
19 changes: 0 additions & 19 deletions server/http_server/setup_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package http_server
import (
"fmt"
"io/fs"
"net"
"net/http"
"pmail/config"
"pmail/controllers"
Expand Down Expand Up @@ -49,21 +48,3 @@ func SetupStop() {
panic(err)
}
}

func getLocalIP() string {
ip := "127.0.0.1"
addrs, err := net.InterfaceAddrs()
if err != nil {
return ip
}
for _, a := range addrs {
if ipnet, ok := a.(*net.IPNet); ok && !ipnet.IP.IsLoopback() {
if ipnet.IP.To4() != nil {
ip = ipnet.IP.String()
break
}
}
}

return ip
}
5 changes: 4 additions & 1 deletion server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@ func (l *logFormatter) Format(entry *log.Entry) ([]byte, error) {
b.WriteString(fmt.Sprintf("[%s]", entry.Level.String()))
b.WriteString(fmt.Sprintf("[%s]", entry.Time.Format("2006-01-02 15:04:05")))
if entry.Context != nil {
b.WriteString(fmt.Sprintf("[%s]", entry.Context.(*context.Context).GetValue(context.LogID)))
ctx := entry.Context.(*context.Context)
if ctx != nil {
b.WriteString(fmt.Sprintf("[%s]", ctx.GetValue(context.LogID)))
}
}
b.WriteString(fmt.Sprintf("[%s:%d]", entry.Caller.File, entry.Caller.Line))
b.WriteString(entry.Message)
Expand Down
7 changes: 4 additions & 3 deletions server/services/rule/rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (
func GetAllRules(ctx *context.Context) []*dto.Rule {
var res []*models.Rule
var err error
if ctx == nil {
if ctx == nil || ctx.UserID == 0 {
err = db.Instance.Select(&res, "select * from rule order by sort desc")
} else {
err = db.Instance.Select(&res, db.WithContext(ctx, "select * from rule where user_id=? order by sort desc"), ctx.UserID)
Expand Down Expand Up @@ -60,6 +60,8 @@ func MatchRule(ctx *context.Context, rule *dto.Rule, email *parsemail.Email) boo
}

func DoRule(ctx *context.Context, rule *dto.Rule, email *parsemail.Email) {
log.WithContext(ctx).Debugf("执行规则:%s", rule.Name)

switch rule.Action {
case dto.READ:
email.IsRead = 1
Expand All @@ -70,8 +72,7 @@ func DoRule(ctx *context.Context, rule *dto.Rule, email *parsemail.Email) {
log.WithContext(ctx).Errorf("Forward Error! loop forwarding!")
return
}

err := send.Forward(nil, email, rule.Params)
err := send.Forward(ctx, email, rule.Params)
if err != nil {
log.WithContext(ctx).Errorf("Forward Error:%v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion server/services/setup/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ func GetDatabaseSettings(ctx *context.Context) (string, string, error) {
}

if configData.DbType == "" && configData.DbDSN == "" {
return config.DBTypeSQLite, "./pmail.db", nil
return config.DBTypeSQLite, "./config/pmail.db", nil
}

return configData.DbType, configData.DbDSN, nil
Expand Down
31 changes: 20 additions & 11 deletions server/smtp_server/read_content.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,24 @@ import (
"pmail/hooks"
"pmail/services/rule"
"pmail/utils/async"
"pmail/utils/context"
"pmail/utils/id"
"strings"
"time"
)

func (s *Session) Data(r io.Reader) error {
ctx := &context.Context{}
ctx.SetValue(context.LogID, id.GenLogID())
log.WithContext(ctx).Debugf("收到邮件")

emailData, err := io.ReadAll(r)
if err != nil {
log.Error("邮件内容无法读取", err)
log.WithContext(ctx).Error("邮件内容无法读取", err)
return err
}

as1 := async.New(nil)
as1 := async.New(ctx)
for _, hook := range hooks.HookList {
if hook == nil {
continue
Expand All @@ -36,7 +42,7 @@ func (s *Session) Data(r io.Reader) error {
}
as1.Wait()

log.Infof("邮件原始内容: %s", emailData)
log.WithContext(ctx).Infof("邮件原始内容: %s", emailData)

var dkimStatus, SPFStatus bool

Expand All @@ -46,7 +52,7 @@ func (s *Session) Data(r io.Reader) error {
email := parsemail.NewEmailFromReader(bytes.NewReader(emailData))

if err != nil {
log.Fatalf("邮件内容解析失败! Error : %v \n", err)
log.WithContext(ctx).Errorf("邮件内容解析失败! Error : %v \n", err)
}

SPFStatus = spfCheck(s.RemoteAddress.String(), email.Sender, email.Sender.EmailAddress)
Expand All @@ -61,16 +67,17 @@ func (s *Session) Data(r io.Reader) error {

// 垃圾过滤
if config.Instance.SpamFilterLevel == 1 && !SPFStatus && !dkimStatus {
log.Infoln("垃圾邮件,拒信")
log.WithContext(ctx).Infoln("垃圾邮件,拒信")
return nil
}

if config.Instance.SpamFilterLevel == 2 && !SPFStatus {
log.Infoln("垃圾邮件,拒信")
log.WithContext(ctx).Infoln("垃圾邮件,拒信")
return nil
}
log.WithContext(ctx).Debugf("开始执行插件!")

as2 := async.New(nil)
as2 := async.New(ctx)
for _, hook := range hooks.HookList {
if hook == nil {
continue
Expand All @@ -81,13 +88,15 @@ func (s *Session) Data(r io.Reader) error {
}
as2.Wait()

log.WithContext(ctx).Debugf("开始执行邮件规则!")
// 执行邮件规则
rs := rule.GetAllRules(nil)
rs := rule.GetAllRules(ctx)
for _, r := range rs {
if rule.MatchRule(nil, r, email) {
rule.DoRule(nil, r, email)
if rule.MatchRule(ctx, r, email) {
rule.DoRule(ctx, r, email)
}
}
log.WithContext(ctx).Debugf("开始入库!")

if email == nil {
return nil
Expand Down Expand Up @@ -116,7 +125,7 @@ func (s *Session) Data(r io.Reader) error {
)

if err != nil {
log.Println("mysql insert error:", err.Error())
log.WithContext(ctx).Println("mysql insert error:", err.Error())
}

return nil
Expand Down
57 changes: 57 additions & 0 deletions server/utils/id/logid.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package id

import (
"bytes"
"encoding/hex"
"fmt"
"math/rand"
"net"
"os"
"time"
)

var ip_instance string

func getLocalIP() string {
if ip_instance != "" {
return ip_instance
}

ip := "127.0.0.1"
addrs, err := net.InterfaceAddrs()
if err != nil {
ip_instance = ip
return ip
}
for _, a := range addrs {
if ipnet, ok := a.(*net.IPNet); ok && !ipnet.IP.IsLoopback() {
if ipnet.IP.To4() != nil {
ip = ipnet.IP.String()
break
}
}
}
ip_instance = ip
return ip
}

func GenLogID() string {
r := rand.New(rand.NewSource(time.Now().UnixMicro()))

ip := getLocalIP()

now := time.Now()
timestamp := uint32(now.Unix())
timeNano := now.UnixNano()
pid := os.Getpid()
b := bytes.Buffer{}

b.WriteString(hex.EncodeToString(net.ParseIP(ip).To4()))
b.WriteString(fmt.Sprintf("%x", timestamp&0xffffffff))
b.WriteString(fmt.Sprintf("%04x", timeNano&0xffff))
b.WriteString(fmt.Sprintf("%04x", pid&0xffff))
b.WriteString(fmt.Sprintf("%06x", r.Int31n(1<<24)))
b.WriteString("b0")

return b.String()
}
Loading

0 comments on commit e6ddfba

Please sign in to comment.