Skip to content

Commit

Permalink
feat: 提供更灵活的钩子方法声明
Browse files Browse the repository at this point in the history
  • Loading branch information
yeaha committed May 1, 2024
1 parent 6006fa0 commit 404d560
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 25 deletions.
110 changes: 98 additions & 12 deletions entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,43 @@ type Event int
// Entity 实体对象接口
type Entity interface {
TableName() string
}

// EventHook 事件风格钩子
type EventHook interface {
OnEntityEvent(ctx context.Context, ev Event) error
}

// BeforeInsertHook 在插入前调用
type BeforeInsertHook interface {
BeforeInsert(ctx context.Context) error
}

// AfterInsertHook 在插入后调用
type AfterInsertHook interface {
AfterInsert(ctx context.Context) error
}

// BeforeUpdateHook 在更新前调用
type BeforeUpdateHook interface {
BeforeUpdate(ctx context.Context) error
}

// AfterUpdateHook 在更新后调用
type AfterUpdateHook interface {
AfterUpdate(ctx context.Context) error
}

// BeforeDeleteHook 在删除前调用
type BeforeDeleteHook interface {
BeforeDelete(ctx context.Context) error
}

// AfterDeleteHook 在删除后调用
type AfterDeleteHook interface {
AfterDelete(ctx context.Context) error
}

// Column 字段信息
type Column struct {
StructField string
Expand Down Expand Up @@ -228,7 +262,7 @@ func Insert(ctx context.Context, ent Entity, db DB) (int64, error) {
ctx, cancel := context.WithTimeout(ctx, WriteTimeout)
defer cancel()

if err := ent.OnEntityEvent(ctx, EventBeforeInsert); err != nil {
if err := beforeInsert(ctx, ent); err != nil {
return 0, fmt.Errorf("before insert, %w", err)
}

Expand All @@ -240,10 +274,9 @@ func Insert(ctx context.Context, ent Entity, db DB) (int64, error) {
return 0, err
}

if err := ent.OnEntityEvent(ctx, EventAfterInsert); err != nil {
if err := afterInsert(ctx, ent); err != nil {
return 0, fmt.Errorf("after insert, %w", err)
}

return lastID, nil
}

Expand All @@ -252,7 +285,7 @@ func Update(ctx context.Context, ent Entity, db DB) error {
ctx, cancel := context.WithTimeout(ctx, WriteTimeout)
defer cancel()

if err := ent.OnEntityEvent(ctx, EventBeforeUpdate); err != nil {
if err := beforeUpdate(ctx, ent); err != nil {
return fmt.Errorf("before update, %w", err)
}

Expand All @@ -269,7 +302,7 @@ func Update(ctx context.Context, ent Entity, db DB) error {
}
}

if err := ent.OnEntityEvent(ctx, EventAfterUpdate); err != nil {
if err := afterUpdate(ctx, ent); err != nil {
return fmt.Errorf("after update, %w", err)
}
return nil
Expand All @@ -280,7 +313,7 @@ func Delete(ctx context.Context, ent Entity, db DB) error {
ctx, cancel := context.WithTimeout(ctx, WriteTimeout)
defer cancel()

if err := ent.OnEntityEvent(ctx, EventBeforeDelete); err != nil {
if err := beforeDelete(ctx, ent); err != nil {
return fmt.Errorf("before delete, %w", err)
}

Expand All @@ -294,7 +327,7 @@ func Delete(ctx context.Context, ent Entity, db DB) error {
}
}

if err := ent.OnEntityEvent(ctx, EventAfterDelete); err != nil {
if err := afterDelete(ctx, ent); err != nil {
return fmt.Errorf("after delete, %w", err)
}
return nil
Expand Down Expand Up @@ -337,7 +370,7 @@ func (pis *PrepareInsertStatement) ExecContext(ctx context.Context, ent Entity)
ctx, cancel := context.WithTimeout(ctx, WriteTimeout)
defer cancel()

if err := ent.OnEntityEvent(ctx, EventBeforeInsert); err != nil {
if err := beforeInsert(ctx, ent); err != nil {
return 0, fmt.Errorf("before insert, %w", err)
}

Expand All @@ -349,10 +382,9 @@ func (pis *PrepareInsertStatement) ExecContext(ctx context.Context, ent Entity)
return 0, err
}

if err := ent.OnEntityEvent(ctx, EventAfterInsert); err != nil {
if err := afterInsert(ctx, ent); err != nil {
return 0, fmt.Errorf("after insert, %w", err)
}

return lastID, nil
}

Expand Down Expand Up @@ -415,7 +447,7 @@ func (pus *PrepareUpdateStatement) ExecContext(ctx context.Context, ent Entity)
ctx, cancel := context.WithTimeout(ctx, WriteTimeout)
defer cancel()

if err := ent.OnEntityEvent(ctx, EventBeforeUpdate); err != nil {
if err := beforeUpdate(ctx, ent); err != nil {
return fmt.Errorf("before update, %w", err)
}

Expand All @@ -432,7 +464,7 @@ func (pus *PrepareUpdateStatement) ExecContext(ctx context.Context, ent Entity)
}
}

if err := ent.OnEntityEvent(ctx, EventAfterUpdate); err != nil {
if err := afterUpdate(ctx, ent); err != nil {
return fmt.Errorf("after update, %w", err)
}
return nil
Expand All @@ -455,3 +487,57 @@ func (pus *PrepareUpdateStatement) execContext(ctx context.Context, ent Entity)
}
return nil
}

func beforeInsert(ctx context.Context, ent Entity) error {
if v, ok := ent.(BeforeInsertHook); ok {
return v.BeforeInsert(ctx)
} else if v, ok := ent.(EventHook); ok {
return v.OnEntityEvent(ctx, EventBeforeInsert)
}
return nil
}

func afterInsert(ctx context.Context, ent Entity) error {
if v, ok := ent.(AfterInsertHook); ok {
return v.AfterInsert(ctx)
} else if v, ok := ent.(EventHook); ok {
return v.OnEntityEvent(ctx, EventAfterInsert)
}
return nil
}

func beforeUpdate(ctx context.Context, ent Entity) error {
if v, ok := ent.(BeforeUpdateHook); ok {
return v.BeforeUpdate(ctx)
} else if v, ok := ent.(EventHook); ok {
return v.OnEntityEvent(ctx, EventBeforeUpdate)
}
return nil
}

func afterUpdate(ctx context.Context, ent Entity) error {
if v, ok := ent.(AfterUpdateHook); ok {
return v.AfterUpdate(ctx)
} else if v, ok := ent.(EventHook); ok {
return v.OnEntityEvent(ctx, EventAfterUpdate)
}
return nil
}

func beforeDelete(ctx context.Context, ent Entity) error {
if v, ok := ent.(BeforeDeleteHook); ok {
return v.BeforeDelete(ctx)
} else if v, ok := ent.(EventHook); ok {
return v.OnEntityEvent(ctx, EventBeforeDelete)
}
return nil
}

func afterDelete(ctx context.Context, ent Entity) error {
if v, ok := ent.(AfterDeleteHook); ok {
return v.AfterDelete(ctx)
} else if v, ok := ent.(EventHook); ok {
return v.OnEntityEvent(ctx, EventAfterDelete)
}
return nil
}
13 changes: 0 additions & 13 deletions entity_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package entity

import (
"context"
"reflect"
"testing"
"time"
Expand Down Expand Up @@ -191,10 +190,6 @@ func (ge GenernalEntity) TableName() string {
return "genernal"
}

func (ge GenernalEntity) OnEntityEvent(ctx context.Context, ev Event) error {
return nil
}

type EmptyEntity struct {
ID int `db:"-"`
Name string `db:"-"`
Expand All @@ -204,10 +199,6 @@ func (ee EmptyEntity) TableName() string {
return "emtpy"
}

func (ee *EmptyEntity) OnEntityEvent(ctx context.Context, ev Event) error {
return nil
}

type NoPrimaryKeyEntity struct {
ID int `db:"int"`
Name string `db:"name"`
Expand All @@ -216,7 +207,3 @@ type NoPrimaryKeyEntity struct {
func (npe NoPrimaryKeyEntity) TableName() string {
return "no_primary_key"
}

func (npe *NoPrimaryKeyEntity) OnEntityEvent(ctx context.Context, ev Event) error {
return nil
}

0 comments on commit 404d560

Please sign in to comment.