diff --git a/README.md b/README.md index ff9f715..22b4e95 100644 --- a/README.md +++ b/README.md @@ -33,16 +33,17 @@ If use distributed-lock to implement it. I will depends on the system-time of ea 1. Create redisDriver instance, set the `ServiceName` and initialize `dcron`. The `ServiceName` will defined the same task unit. ```golang - drv, _ := redis.NewDriver(&redis.Options{ - Host: "127.0.0.1:6379" - }) - dcron := NewDcron("server1", drv) +redisCli := redis.NewClient(&redis.Options{ + Addr: DefaultRedisAddr, +}) +drv := driver.NewRedisDriver(redisCli) +dcron := NewDcron("server1", drv) ``` 2. Use cron-language to add task, you should set the `TaskName`, the `TaskName` is the primary-key of each task. ```golang - dcron.AddFunc("test1","*/3 * * * *",func(){ - fmt.Println("execute test1 task",time.Now().Format("15:04:05")) - }) +dcron.AddFunc("test1","*/3 * * * *",func(){ + fmt.Println("execute test1 task",time.Now().Format("15:04:05")) +}) ``` 3. Begin the task ```golang diff --git a/README_CN.md b/README_CN.md index c38f082..15f9544 100644 --- a/README_CN.md +++ b/README_CN.md @@ -32,16 +32,19 @@ a lightweight distributed job scheduler library based on redis or etcd 1.创建redisDriver实例,指定服务名并初始化dcron。服务名为执行相同任务的单元。 ```golang - drv, _ := redis.NewDriver(&redis.Options{ - Host: "127.0.0.1:6379" - }) - dcron := NewDcron("server1", drv) +redisCli := redis.NewClient(&redis.Options{ + Addr: DefaultRedisAddr, +}) +drv := driver.NewRedisDriver(redisCli) +dcron := NewDcron("server1", drv) ``` +当然,如果你可以自己实现一个自定义的Driver也是可以的,只需要实现[DriverV2](driver/driver.go)接口即可。 + 2.使用cron语法添加任务,需要指定任务名。任务名作为任务的唯一标识,必须保证唯一。 ```golang - dcron.AddFunc("test1","*/3 * * * *",func(){ - fmt.Println("执行 test1 任务",time.Now().Format("15:04:05")) - }) +dcron.AddFunc("test1","*/3 * * * *",func(){ + fmt.Println("执行 test1 任务",time.Now().Format("15:04:05")) +}) ``` 3.开始任务。 ```golang diff --git a/dcron.go b/dcron.go index 907e5b5..b5a74de 100644 --- a/dcron.go +++ b/dcron.go @@ -1,6 +1,7 @@ package dcron 
import ( + "context" "errors" "log" "os" @@ -15,7 +16,7 @@ import ( const ( defaultReplicas = 50 - defaultDuration = time.Second + defaultDuration = 3 * time.Second ) const ( @@ -31,7 +32,7 @@ type Dcron struct { jobsRWMut sync.Mutex ServerName string - nodePool *NodePool + nodePool INodePool running int32 logger dlog.Logger @@ -47,34 +48,24 @@ type Dcron struct { } // NewDcron create a Dcron -func NewDcron(serverName string, driver driver.Driver, cronOpts ...cron.Option) *Dcron { +func NewDcron(serverName string, driver driver.DriverV2, cronOpts ...cron.Option) *Dcron { dcron := newDcron(serverName) dcron.crOptions = cronOpts dcron.cr = cron.New(cronOpts...) dcron.running = dcronStopped - var err error - dcron.nodePool, err = newNodePool(serverName, driver, dcron, dcron.nodeUpdateDuration, dcron.hashReplicas) - if err != nil { - dcron.logger.Errorf("ERR: %s", err.Error()) - return nil - } + dcron.nodePool = NewNodePool(serverName, driver, dcron.nodeUpdateDuration, dcron.hashReplicas, dcron.logger) return dcron } // NewDcronWithOption create a Dcron with Dcron Option -func NewDcronWithOption(serverName string, driver driver.Driver, dcronOpts ...Option) *Dcron { +func NewDcronWithOption(serverName string, driver driver.DriverV2, dcronOpts ...Option) *Dcron { dcron := newDcron(serverName) for _, opt := range dcronOpts { opt(dcron) } dcron.cr = cron.New(dcron.crOptions...) 
- var err error - dcron.nodePool, err = newNodePool(serverName, driver, dcron, dcron.nodeUpdateDuration, dcron.hashReplicas) - if err != nil { - dcron.logger.Errorf("ERR: %s", err.Error()) - return nil - } + dcron.nodePool = NewNodePool(serverName, driver, dcron.nodeUpdateDuration, dcron.hashReplicas, dcron.logger) return dcron } @@ -112,7 +103,7 @@ func (d *Dcron) AddFunc(jobName, cronStr string, cmd func()) (err error) { return d.addJob(jobName, cronStr, cmd, nil) } func (d *Dcron) addJob(jobName, cronStr string, cmd func(), job Job) (err error) { - d.logger.Infof("addJob '%s' : %s", jobName, cronStr) + d.logger.Infof("addJob '%s' : %s", jobName, cronStr) d.jobsRWMut.Lock() defer d.jobsRWMut.Unlock() @@ -147,13 +138,7 @@ func (d *Dcron) Remove(jobName string) { } func (d *Dcron) allowThisNodeRun(jobName string) bool { - allowRunNode := d.nodePool.PickNodeByJobName(jobName) - d.logger.Infof("job '%s' running in node %s", jobName, allowRunNode) - if allowRunNode == "" { - d.logger.Errorf("node pool is empty") - return false - } - return d.nodePool.NodeID == allowRunNode + return d.nodePool.CheckJobAvailable(jobName) } // Start job @@ -169,7 +154,7 @@ func (d *Dcron) Start() { return } d.cr.Start() - d.logger.Infof("dcron started , nodeID is %s", d.nodePool.NodeID) + d.logger.Infof("dcron started , nodeID is %s", d.nodePool.GetNodeID()) } else { d.logger.Infof("dcron have started") } @@ -177,13 +162,17 @@ func (d *Dcron) Start() { // Run Job func (d *Dcron) Run() { + // recover jobs before starting + if d.RecoverFunc != nil { + d.RecoverFunc(d) + } if atomic.CompareAndSwapInt32(&d.running, dcronStopped, dcronRunning) { if err := d.startNodePool(); err != nil { atomic.StoreInt32(&d.running, dcronStopped) return } - d.logger.Infof("dcron running nodeID is %s", d.nodePool.NodeID) + d.logger.Infof("dcron running nodeID is %s", d.nodePool.GetNodeID()) d.cr.Run() } else { d.logger.Infof("dcron already running") @@ -191,7 +180,7 @@ func (d *Dcron) Run() { } func (d *Dcron) 
startNodePool() error { - if err := d.nodePool.StartPool(); err != nil { + if err := d.nodePool.Start(context.Background()); err != nil { d.logger.Errorf("dcron start node pool error %+v", err) return err } @@ -201,6 +190,7 @@ func (d *Dcron) startNodePool() error { // Stop job func (d *Dcron) Stop() { tick := time.NewTicker(time.Millisecond) + d.nodePool.Stop(context.Background()) for range tick.C { if atomic.CompareAndSwapInt32(&d.running, dcronRunning, dcronStopped) { d.cr.Stop() diff --git a/dcron_test.go b/dcron_test.go index 8a8acfc..b6e9392 100644 --- a/dcron_test.go +++ b/dcron_test.go @@ -4,14 +4,20 @@ import ( "fmt" "log" "os" + "sync" "testing" "time" "github.com/go-redis/redis/v8" "github.com/libi/dcron" "github.com/libi/dcron/dlog" - RedisDriver "github.com/libi/dcron/driver/redis" + "github.com/libi/dcron/driver" "github.com/robfig/cron/v3" + "github.com/stretchr/testify/require" +) + +const ( + DefaultRedisAddr = "127.0.0.1:6379" ) type TestJob1 struct { @@ -24,92 +30,25 @@ func (t TestJob1) Run() { var testData = make(map[string]struct{}) -func Test(t *testing.T) { - drv, err := RedisDriver.NewDriver(&redis.Options{ - Addr: "127.0.0.1:6379", - }) - - if err != nil { - t.Error(err) - } +func TestMultiNodes(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(3) - go runNode(t, drv) + go runNode(t, wg) // 间隔1秒启动测试节点刷新逻辑 time.Sleep(time.Second) - go runNode(t, drv) - time.Sleep(time.Second * 2) - go runNode(t, drv) - - //add recover - dcron2 := dcron.NewDcron("server2", drv, cron.WithChain(cron.Recover(cron.DefaultLogger))) - dcron2.Start() - dcron2.Stop() - - //panic recover test - err = dcron2.AddFunc("s2 test1", "* * * * *", func() { - panic("panic test") - }) - if err != nil { - t.Fatal("add func error") - } - err = dcron2.AddFunc("s2 test2", "* * * * *", func() { - t.Log("执行 service2 test2 任务", time.Now().Format("15:04:05")) - }) - if err != nil { - t.Fatal("add func error") - } - err = dcron2.AddFunc("s2 test3", "* * * * *", func() { - t.Log("执行 
service2 test3 任务", time.Now().Format("15:04:05")) - }) - if err != nil { - t.Fatal("add func error") - } - dcron2.Start() - - // set logger - logger := &dlog.StdLogger{ - Log: log.New(os.Stdout, "[test_s3]", log.LstdFlags), - } - // wrap cron recover - rec := dcron.CronOptionChain(cron.Recover(cron.PrintfLogger(logger))) - - // option test - dcron3 := dcron.NewDcronWithOption("server3", drv, rec, - dcron.WithLogger(logger), - dcron.WithHashReplicas(10), - dcron.WithNodeUpdateDuration(time.Second*10)) - - //panic recover test - err = dcron3.AddFunc("s3 test1", "* * * * *", func() { - t.Log("执行 server3 test1 任务,模拟 panic", time.Now().Format("15:04:05")) - panic("panic test") - }) - if err != nil { - t.Fatal("add func error") - } - - err = dcron3.AddFunc("s3 test2", "* * * * *", func() { - t.Log("执行 server3 test2 任务", time.Now().Format("15:04:05")) - }) - if err != nil { - t.Fatal("add func error") - } - err = dcron3.AddFunc("s3 test3", "* * * * *", func() { - t.Log("执行 server3 test3 任务", time.Now().Format("15:04:05")) - }) - if err != nil { - t.Fatal("add func error") - } - dcron3.Start() + go runNode(t, wg) + time.Sleep(time.Second) + go runNode(t, wg) - //测试120秒后退出 - time.Sleep(120 * time.Second) - t.Log("testData", testData) - dcron2.Stop() - dcron3.Stop() + wg.Wait() } -func runNode(t *testing.T, drv *RedisDriver.RedisDriver) { +func runNode(t *testing.T, wg *sync.WaitGroup) { + redisCli := redis.NewClient(&redis.Options{ + Addr: DefaultRedisAddr, + }) + drv := driver.NewRedisDriver(redisCli) dcron := dcron.NewDcron("server1", drv) //添加多个任务 启动多个节点时 任务会均匀分配给各个节点 @@ -147,17 +86,18 @@ func runNode(t *testing.T, drv *RedisDriver.RedisDriver) { //移除测试 dcron.Remove("s1 test3") + <-time.After(120 * time.Second) + wg.Done() + dcron.Stop() } func Test_SecondsJob(t *testing.T) { - drv, err := RedisDriver.NewDriver(&redis.Options{ - Addr: "127.0.0.1:6379", + redisCli := redis.NewClient(&redis.Options{ + Addr: DefaultRedisAddr, }) - if err != nil { - t.Error(err) - } + drv 
:= driver.NewRedisDriver(redisCli) dcr := dcron.NewDcronWithOption(t.Name(), drv, dcron.CronOptionSeconds()) - err = dcr.AddFunc("job1", "*/5 * * * * *", func() { + err := dcr.AddFunc("job1", "*/5 * * * * *", func() { t.Log(time.Now()) }) if err != nil { @@ -167,3 +107,55 @@ func Test_SecondsJob(t *testing.T) { time.Sleep(15 * time.Second) dcr.Stop() } + +func runSecondNode(id string, wg *sync.WaitGroup, runningTime time.Duration, t *testing.T) { + redisCli := redis.NewClient(&redis.Options{ + Addr: DefaultRedisAddr, + }) + drv := driver.NewRedisDriver(redisCli) + dcr := dcron.NewDcronWithOption(t.Name(), drv, + dcron.CronOptionSeconds(), + dcron.WithLogger(&dlog.StdLogger{ + Log: log.New(os.Stdout, "["+id+"]", log.LstdFlags), + }), + dcron.CronOptionChain(cron.Recover( + cron.DefaultLogger, + )), + ) + var err error + err = dcr.AddFunc("job1", "*/5 * * * * *", func() { + t.Log(time.Now()) + }) + require.Nil(t, err) + err = dcr.AddFunc("job2", "*/8 * * * * *", func() { + panic("test panic") + }) + require.Nil(t, err) + err = dcr.AddFunc("job3", "*/2 * * * * *", func() { + t.Log("job3:", time.Now()) + }) + require.Nil(t, err) + dcr.Start() + <-time.After(runningTime) + dcr.Stop() + wg.Done() +} + +func Test_SecondJobWithPanicAndMultiNodes(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(5) + go runSecondNode("1", wg, 45*time.Second, t) + go runSecondNode("2", wg, 45*time.Second, t) + go runSecondNode("3", wg, 45*time.Second, t) + go runSecondNode("4", wg, 45*time.Second, t) + go runSecondNode("5", wg, 45*time.Second, t) + wg.Wait() +} + +func Test_SecondJobWithStopAndSwapNode(t *testing.T) { + wg := &sync.WaitGroup{} + wg.Add(2) + go runSecondNode("1", wg, 60*time.Second, t) + go runSecondNode("2", wg, 20*time.Second, t) + wg.Wait() +} diff --git a/dlog/logger.go b/dlog/logger.go index 98857f1..168b61a 100644 --- a/dlog/logger.go +++ b/dlog/logger.go @@ -1,11 +1,17 @@ package dlog -import "log" +import ( + "testing" +) type PrintfLogger interface { Printf(string, 
...interface{}) } +type LogfLogger interface { + Logf(string, ...interface{}) +} + type Logger interface { PrintfLogger Infof(string, ...interface{}) @@ -14,7 +20,7 @@ type Logger interface { } type StdLogger struct { - Log *log.Logger + Log PrintfLogger } func (l *StdLogger) Infof(format string, args ...interface{}) { @@ -32,3 +38,21 @@ func (l *StdLogger) Errorf(format string, args ...interface{}) { func (l *StdLogger) Printf(format string, args ...interface{}) { l.Log.Printf(format, args...) } + +type PrintfLoggerFromLogfLogger struct { + Log LogfLogger +} + +func (l *PrintfLoggerFromLogfLogger) Printf(fmt string, args ...interface{}) { + l.Log.Logf(fmt, args) +} + +func NewPrintfLoggerFromLogfLogger(logger LogfLogger) PrintfLogger { + return &PrintfLoggerFromLogfLogger{Log: logger} +} + +func NewLoggerForTest(t *testing.T) Logger { + return &StdLogger{ + Log: NewPrintfLoggerFromLogfLogger(t), + } +} diff --git a/driver/driver.go b/driver/driver.go index 8b44ce7..a01c1aa 100644 --- a/driver/driver.go +++ b/driver/driver.go @@ -1,18 +1,37 @@ package driver import ( - "time" + "context" - "github.com/libi/dcron/dlog" + "github.com/go-redis/redis/v8" + clientv3 "go.etcd.io/etcd/client/v3" ) -//Driver is a driver interface -type Driver interface { - // Ping is check dirver is valid - Ping() error - SetLogger(log dlog.Logger) - SetHeartBeat(nodeID string) - SetTimeout(timeout time.Duration) - GetServiceNodeList(ServiceName string) ([]string, error) - RegisterServiceNode(ServiceName string) (string, error) +// There is only one driver for one dcron. +// Tips for write a user-defined Driver by yourself. +// 1. Confirm that `Stop` and `Start` can be called for more times. +// 2. Must make `GetNodes` will return error when timeout. 
+type DriverV2 interface { + // init driver + Init(serviceName string, opts ...Option) + // get nodeID + NodeID() string + // get nodes + GetNodes(ctx context.Context) (nodes []string, err error) + Start(ctx context.Context) (err error) + Stop(ctx context.Context) (err error) + + withOption(opt Option) (err error) +} + +func NewRedisDriver(redisClient *redis.Client) DriverV2 { + return newRedisDriver(redisClient) +} + +func NewEtcdDriver(etcdCli *clientv3.Client) DriverV2 { + return newEtcdDriver(etcdCli) +} + +func NewRedisZSetDriver(redisClient *redis.Client) DriverV2 { + return newRedisZSetDriver(redisClient) } diff --git a/driver/driver_test.go b/driver/driver_test.go deleted file mode 100644 index 380c085..0000000 --- a/driver/driver_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package driver_test - -import ( - "flag" - "testing" - "time" - - "github.com/go-redis/redis/v8" - DcronDriver "github.com/libi/dcron/driver" - RedisDriver "github.com/libi/dcron/driver/redis" - "github.com/stretchr/testify/require" -) - -var ( - redisAddr = flag.String("rAddr", "127.0.0.1:6379", "redis serve addr") - password = flag.String("password", "", "redis password") -) - -// you should run this test when the redis is served. -// you can run test like below command. 
-// go test -v --rAddr 127.0.0.1:6379 -password 123456 -// rAddr is the redis serve addr - -func TestRedisDriver(t *testing.T) { - t.Logf("test redis serve on %s", *redisAddr) - - serviceName := t.Name() - NewDriverFunc := func(_ int) (DcronDriver.Driver, error) { - driver, err := RedisDriver.NewDriver(&redis.Options{ - Addr: *redisAddr, - Password: *password, - }) - require.Nil(t, err) - require.Nil(t, driver.Ping()) - driver.SetTimeout(5 * time.Second) - nodeId, err := driver.RegisterServiceNode(serviceName) - require.Nil(t, err) - driver.SetHeartBeat(nodeId) - return driver, nil - } - n := 10 - drivers := make([]DcronDriver.Driver, 0) - for i := 0; i < n; i++ { - dr, err := NewDriverFunc(i) - require.Nilf(t, err, "new driver error %d", i) - drivers = append(drivers, dr) - } - - for i := 0; i < n; i++ { - nodeIds, err := drivers[i].GetServiceNodeList(serviceName) - require.Nilf(t, err, "get service nodelist error %d", i) - require.Equal(t, n, len(nodeIds)) - } -} diff --git a/driver/etcd/etcd_driver.go b/driver/etcd/etcd_driver.go deleted file mode 100644 index 4b3fb03..0000000 --- a/driver/etcd/etcd_driver.go +++ /dev/null @@ -1,230 +0,0 @@ -package etcd - -import ( - "context" - "log" - "sync" - "time" - - "github.com/google/uuid" - "github.com/libi/dcron/dlog" - "github.com/libi/dcron/driver" - "go.etcd.io/etcd/api/v3/mvccpb" - clientv3 "go.etcd.io/etcd/client/v3" -) - -var _ driver.Driver = &EtcdDriver{} - -const ( - defaultLease = 5 // 5 second ttl - dialTimeout = 3 * time.Second - businessTimeout = 5 * time.Second -) - -type EtcdDriver struct { - cli *clientv3.Client - lease int64 - serverList map[string]map[string]string - lock sync.RWMutex - leaseID clientv3.LeaseID - logger dlog.Logger -} - -//NewEtcdDriver ... 
-func NewEtcdDriver(config *clientv3.Config) (*EtcdDriver, error) { - cli, err := clientv3.New(*config) - if err != nil { - return nil, err - } - - ser := &EtcdDriver{ - cli: cli, - serverList: make(map[string]map[string]string, 10), - logger: &dlog.StdLogger{ - Log: log.Default(), - }, - } - - return ser, nil -} - -//设置key value,绑定租约 -func (s *EtcdDriver) putKeyWithLease(key, val string) (clientv3.LeaseID, error) { - //设置租约时间,最少5s - if s.lease < defaultLease { - s.lease = defaultLease - } - - ctx, cancel := context.WithTimeout(context.Background(), businessTimeout) - defer cancel() - - resp, err := s.cli.Grant(ctx, s.lease) - if err != nil { - return 0, err - } - //注册服务并绑定租约 - _, err = s.cli.Put(ctx, key, val, clientv3.WithLease(resp.ID)) - if err != nil { - return 0, err - } - - return resp.ID, nil -} - -func (s *EtcdDriver) randNodeID(serviceName string) (nodeID string) { - return getPrefix(serviceName) + uuid.New().String() -} - -//WatchService 初始化服务列表和监视 -func (s *EtcdDriver) watchService(serviceName string) error { - prefix := getPrefix(serviceName) - // 根据前缀获取现有的key - resp, err := s.cli.Get(context.Background(), prefix, clientv3.WithPrefix()) - if err != nil { - return err - } - - for _, ev := range resp.Kvs { - s.setServiceList(serviceName, string(ev.Key), string(ev.Value)) - } - - // 监视前缀,修改变更的server - go s.watcher(serviceName) - return nil -} - -func getPrefix(serviceName string) string { - return serviceName + "/" -} - -// watcher 监听前缀 -func (s *EtcdDriver) watcher(serviceName string) { - prefix := getPrefix(serviceName) - rch := s.cli.Watch(context.Background(), prefix, clientv3.WithPrefix()) - for wresp := range rch { - for _, ev := range wresp.Events { - switch ev.Type { - case mvccpb.PUT: //修改或者新增 - s.setServiceList(serviceName, string(ev.Kv.Key), string(ev.Kv.Value)) - case mvccpb.DELETE: //删除 - s.delServiceList(serviceName, string(ev.Kv.Key)) - } - } - } -} - -// setServiceList 新增服务地址 -func (s *EtcdDriver) setServiceList(serviceName, key, val 
string) { - s.lock.Lock() - defer s.lock.Unlock() - if _, ok := s.serverList[serviceName]; !ok { - nodeMap := map[string]string{ - key: val, - } - s.serverList[serviceName] = nodeMap - } else { - s.serverList[serviceName][key] = val - } -} - -// DelServiceList 删除服务地址 -func (s *EtcdDriver) delServiceList(serviceName, key string) { - s.lock.Lock() - defer s.lock.Unlock() - if nodeMap, ok := s.serverList[serviceName]; ok { - delete(nodeMap, key) - } -} - -// GetServices 获取服务地址 -func (s *EtcdDriver) getServices(serviceName string) []string { - s.lock.RLock() - defer s.lock.RUnlock() - addrs := make([]string, 0) - if nodeMap, ok := s.serverList[serviceName]; ok { - for _, v := range nodeMap { - addrs = append(addrs, v) - } - } - return addrs -} - -func (e *EtcdDriver) Ping() error { - return nil -} - -func (e *EtcdDriver) keepAlive(ctx context.Context, nodeID string) (<-chan *clientv3.LeaseKeepAliveResponse, error) { - var err error - e.leaseID, err = e.putKeyWithLease(nodeID, nodeID) - if err != nil { - e.logger.Errorf("putKeyWithLease error: %v", err) - return nil, err - } - - return e.cli.KeepAlive(ctx, e.leaseID) -} - -func (e *EtcdDriver) revoke() { - _, err := e.cli.Lease.Revoke(context.Background(), e.leaseID) - if err != nil { - log.Printf("lease revoke error: %v", err) - } -} - -func (e *EtcdDriver) SetHeartBeat(nodeID string) { - leaseCh, err := e.keepAlive(context.Background(), nodeID) - if err != nil { - e.logger.Errorf("setHeartBeat error: %v", err) - return - } - go func() { - defer func() { - err := recover() - if err != nil { - e.logger.Errorf("keepAlive panic: %v", err) - return - } - }() - for { - select { - case _, ok := <-leaseCh: - if !ok { - e.revoke() - e.SetHeartBeat(nodeID) - return - } - case <-time.After(businessTimeout): - e.logger.Errorf("ectd cli keepalive timeout") - return - } - } - }() -} - -func (e *EtcdDriver) SetLogger(log dlog.Logger) { - e.logger = log -} - -// SetTimeout set etcd lease timeout -func (e *EtcdDriver) 
SetTimeout(timeout time.Duration) { - e.lease = int64(timeout.Seconds()) -} - -// GetServiceNodeList get service notes -func (e *EtcdDriver) GetServiceNodeList(serviceName string) ([]string, error) { - return e.getServices(serviceName), nil -} - -// RegisterServiceNode register a node to service -func (e *EtcdDriver) RegisterServiceNode(serviceName string) (string, error) { - nodeId := e.randNodeID(serviceName) - _, err := e.putKeyWithLease(nodeId, nodeId) - if err != nil { - return "", err - } - err = e.watchService(serviceName) - if err != nil { - return "", err - } - return nodeId, nil -} diff --git a/driver/etcd/etcd_driver_test.go b/driver/etcd/etcd_driver_test.go deleted file mode 100644 index d7c6083..0000000 --- a/driver/etcd/etcd_driver_test.go +++ /dev/null @@ -1,92 +0,0 @@ -package etcd - -import ( - "testing" - "time" - - "github.com/stretchr/testify/require" - clientv3 "go.etcd.io/etcd/client/v3" - "go.etcd.io/etcd/tests/v3/integration" -) - -func TestEtcdDriver(t *testing.T) { - var lazyCluster = integration.NewLazyCluster() - defer lazyCluster.Terminate() - - ed, err := NewEtcdDriver(&clientv3.Config{ - Endpoints: lazyCluster.EndpointsV3(), - DialTimeout: dialTimeout, - }) - - require.Nil(t, err) - serviceName := "testService" - - nodeMap := make(map[string]string) - - count := 10 - - for i := 0; i < count; i++ { - nodeID, err := ed.RegisterServiceNode(serviceName) - require.Nil(t, err) - t.Logf("nodeId %v:%v", i, nodeID) - nodeMap[nodeID] = nodeID - } - - list, err := ed.GetServiceNodeList(serviceName) - - require.Nil(t, err) - - require.Equal(t, count, len(list)) - - for _, v := range list { - if _, ok := nodeMap[v]; !ok { - t.Errorf("nodeId %v not found!!!", v) - } - } - -} - -func TestSetHeartBeat(t *testing.T) { - - var lazyCluster = integration.NewLazyCluster() - defer lazyCluster.Terminate() - - ed, err := NewEtcdDriver(&clientv3.Config{ - Endpoints: lazyCluster.EndpointsV3(), - DialTimeout: dialTimeout, - }) - - require.Nil(t, err) - 
serviceName := "testService" - - nodeMap := make(map[string]string) - - count := 10 - - //一半设置心跳 - for i := 0; i < count; i++ { - nodeID, err := ed.RegisterServiceNode(serviceName) - require.Nil(t, err) - t.Logf("nodeId %v:%v", i, nodeID) - if i%2 == 0 { - ed.SetHeartBeat(nodeID) - nodeMap[nodeID] = nodeID - } - } - - time.Sleep(time.Second * 10) - - //10s后获取serverList,预期只能取到一半 - list, err := ed.GetServiceNodeList(serviceName) - - require.Nil(t, err) - - require.Equal(t, len(nodeMap), len(list)) - - for _, v := range list { - if _, ok := nodeMap[v]; !ok { - t.Errorf("nodeId %v not found!!!", v) - } - } - -} diff --git a/driver/etcddriver.go b/driver/etcddriver.go new file mode 100644 index 0000000..b489c8b --- /dev/null +++ b/driver/etcddriver.go @@ -0,0 +1,219 @@ +package driver + +import ( + "context" + "log" + "sync" + "time" + + "github.com/libi/dcron/dlog" + "go.etcd.io/etcd/api/v3/mvccpb" + clientv3 "go.etcd.io/etcd/client/v3" +) + +const ( + etcdDefaultLease = 5 // min lease time + etcdDialTimeout = 3 * time.Second + etcdBusinessTimeout = 5 * time.Second +) + +type EtcdDriver struct { + nodeID string + serviceName string + + cli *clientv3.Client + lease int64 + nodes *sync.Map + leaseID clientv3.LeaseID + logger dlog.Logger + + ctx context.Context + cancel context.CancelFunc +} + +// NewEtcdDriver +func newEtcdDriver(cli *clientv3.Client) *EtcdDriver { + ser := &EtcdDriver{ + cli: cli, + nodes: &sync.Map{}, + logger: &dlog.StdLogger{ + Log: log.Default(), + }, + } + + return ser +} + +// 设置key value,绑定租约 +func (e *EtcdDriver) putKeyWithLease(ctx context.Context, key, val string) (clientv3.LeaseID, error) { + //设置租约时间,最少5s + if e.lease < etcdDefaultLease { + e.lease = etcdDefaultLease + } + + subCtx, cancel := context.WithTimeout(ctx, etcdBusinessTimeout) + defer cancel() + resp, err := e.cli.Grant(subCtx, e.lease) + if err != nil { + return 0, err + } + //注册服务并绑定租约 + _, err = e.cli.Put(subCtx, key, val, clientv3.WithLease(resp.ID)) + if err != nil { + return 
0, err + } + + return resp.ID, nil +} + +// WatchService 初始化服务列表和监视 +func (e *EtcdDriver) watchService(ctx context.Context, serviceName string) error { + prefix := GetKeyPre(serviceName) + // 根据前缀获取现有的key + resp, err := e.cli.Get(ctx, prefix, clientv3.WithPrefix()) + if err != nil { + return err + } + + for _, ev := range resp.Kvs { + e.setServiceList(string(ev.Key), string(ev.Value)) + } + + // 监视前缀,修改变更的server + go e.watcher(serviceName) + return nil +} + +// watcher 监听前缀 +func (e *EtcdDriver) watcher(serviceName string) { + prefix := GetKeyPre(serviceName) + rch := e.cli.Watch(context.Background(), prefix, clientv3.WithPrefix()) + for wresp := range rch { + for _, ev := range wresp.Events { + switch ev.Type { + case mvccpb.PUT: //修改或者新增 + e.setServiceList(string(ev.Kv.Key), string(ev.Kv.Value)) + case mvccpb.DELETE: //删除 + e.delServiceList(string(ev.Kv.Key)) + } + } + } +} + +// setServiceList 新增服务地址 +func (e *EtcdDriver) setServiceList(key, val string) { + e.nodes.Store(key, val) +} + +// DelServiceList 删除服务地址 +func (e *EtcdDriver) delServiceList(key string) { + e.nodes.Delete(key) +} + +// GetServices 获取服务地址 +func (e *EtcdDriver) getServices() []string { + addrs := make([]string, 0) + e.nodes.Range(func(key, _ interface{}) bool { + addrs = append(addrs, key.(string)) + return true + }) + return addrs +} + +func (e *EtcdDriver) keepAlive(ctx context.Context, nodeID string) (<-chan *clientv3.LeaseKeepAliveResponse, error) { + var err error + e.leaseID, err = e.putKeyWithLease(ctx, nodeID, nodeID) + if err != nil { + e.logger.Errorf("putKeyWithLease error: %v", err) + return nil, err + } + + return e.cli.KeepAlive(ctx, e.leaseID) +} + +func (e *EtcdDriver) revoke(ctx context.Context) { + _, err := e.cli.Lease.Revoke(ctx, e.leaseID) + if err != nil { + e.logger.Printf("lease revoke error: %v", err) + } +} + +func (e *EtcdDriver) heartBeat(ctx context.Context) { +label: + leaseCh, err := e.keepAlive(ctx, e.nodeID) + if err != nil { + return + } + for { + select { + 
case <-e.ctx.Done(): + { + e.logger.Infof("driver stopped") + return + } + case _, ok := <-leaseCh: + { + // if lease timeout, goto top of + // this function to keepalive + if !ok { + goto label + } + } + case <-time.After(etcdBusinessTimeout): + { + e.logger.Errorf("ectd cli keepalive timeout") + return + } + case <-time.After(time.Duration(e.lease/2) * (time.Second)): + { + // if near to nodes time, + // renew the lease + goto label + } + } + } +} + +func (e *EtcdDriver) Init(serverName string, opts ...Option) { + e.serviceName = serverName + e.nodeID = GetNodeId(serverName) +} + +func (e *EtcdDriver) NodeID() string { + return e.nodeID +} + +func (e *EtcdDriver) GetNodes(ctx context.Context) (nodes []string, err error) { + return e.getServices(), nil +} + +func (e *EtcdDriver) Start(ctx context.Context) (err error) { + // renew a global ctx when start every time + e.ctx, e.cancel = context.WithCancel(context.TODO()) + go e.heartBeat(ctx) + err = e.watchService(ctx, e.serviceName) + if err != nil { + return + } + return nil +} + +func (e *EtcdDriver) Stop(ctx context.Context) (err error) { + e.revoke(ctx) + e.cancel() + return +} + +func (e *EtcdDriver) withOption(opt Option) (err error) { + switch opt.Type() { + case OptionTypeTimeout: + { + e.lease = int64(opt.(TimeoutOption).timeout.Seconds()) + } + case OptionTypeLogger: + { + e.logger = opt.(LoggerOption).logger + } + } + return +} diff --git a/driver/etcddriver_test.go b/driver/etcddriver_test.go new file mode 100644 index 0000000..e0f8225 --- /dev/null +++ b/driver/etcddriver_test.go @@ -0,0 +1,96 @@ +package driver_test + +import ( + "context" + "testing" + "time" + + "github.com/libi/dcron/dlog" + "github.com/libi/dcron/driver" + "github.com/stretchr/testify/require" + clientv3 "go.etcd.io/etcd/client/v3" + "go.etcd.io/etcd/tests/v3/integration" +) + +func testFuncNewEtcdDriver(cfg clientv3.Config) driver.DriverV2 { + cli, err := clientv3.New(cfg) + if err != nil { + panic(err) + } + return 
driver.NewEtcdDriver(cli) +} + +func TestEtcdDriver_GetNodes(t *testing.T) { + etcdsvr := integration.NewLazyCluster() + defer etcdsvr.Terminate() + N := 10 + drvs := make([]driver.DriverV2, 0) + for i := 0; i < N; i++ { + drv := testFuncNewEtcdDriver(clientv3.Config{ + Endpoints: etcdsvr.EndpointsV3(), + DialTimeout: 3 * time.Second, + }) + drv.Init(t.Name(), driver.NewTimeoutOption(5*time.Second), driver.NewLoggerOption(dlog.NewLoggerForTest(t))) + err := drv.Start(context.Background()) + require.Nil(t, err) + drvs = append(drvs, drv) + } + <-time.After(5 * time.Second) + for _, v := range drvs { + nodes, err := v.GetNodes(context.Background()) + require.Nil(t, err) + require.Equal(t, N, len(nodes)) + } + + for _, v := range drvs { + v.Stop(context.Background()) + } +} + +func TestEtcdDriver_Stop(t *testing.T) { + var err error + var nodes []string + etcdsvr := integration.NewLazyCluster() + defer etcdsvr.Terminate() + + drv1 := testFuncNewEtcdDriver(clientv3.Config{ + Endpoints: etcdsvr.EndpointsV3(), + DialTimeout: 3 * time.Second, + }) + drv1.Init(t.Name(), driver.NewTimeoutOption(5*time.Second), driver.NewLoggerOption(dlog.NewLoggerForTest(t))) + + drv2 := testFuncNewEtcdDriver(clientv3.Config{ + Endpoints: etcdsvr.EndpointsV3(), + DialTimeout: 3 * time.Second, + }) + drv2.Init(t.Name(), driver.NewTimeoutOption(5*time.Second), driver.NewLoggerOption(dlog.NewLoggerForTest(t))) + err = drv2.Start(context.Background()) + require.Nil(t, err) + + err = drv1.Start(context.Background()) + require.Nil(t, err) + <-time.After(3 * time.Second) + nodes, err = drv1.GetNodes(context.Background()) + require.Nil(t, err) + require.Len(t, nodes, 2) + + nodes, err = drv2.GetNodes(context.Background()) + require.Nil(t, err) + require.Len(t, nodes, 2) + + drv1.Stop(context.Background()) + + <-time.After(5 * time.Second) + nodes, err = drv2.GetNodes(context.Background()) + require.Nil(t, err) + require.Len(t, nodes, 1) + + err = drv1.Start(context.Background()) + require.Nil(t, 
err) + <-time.After(5 * time.Second) + nodes, err = drv2.GetNodes(context.Background()) + require.Nil(t, err) + require.Len(t, nodes, 2) + + drv2.Stop(context.Background()) +} diff --git a/driver/option.go b/driver/option.go new file mode 100644 index 0000000..a2a15d1 --- /dev/null +++ b/driver/option.go @@ -0,0 +1,26 @@ +package driver + +import ( + "time" + + "github.com/libi/dcron/dlog" +) + +const ( + OptionTypeTimeout = 0x600 + OptionTypeLogger = 0x601 +) + +type Option interface { + Type() int +} + +type TimeoutOption struct{ timeout time.Duration } + +func (to TimeoutOption) Type() int { return OptionTypeTimeout } +func NewTimeoutOption(timeout time.Duration) TimeoutOption { return TimeoutOption{timeout: timeout} } + +type LoggerOption struct{ logger dlog.Logger } + +func (to LoggerOption) Type() int { return OptionTypeLogger } +func NewLoggerOption(logger dlog.Logger) LoggerOption { return LoggerOption{logger: logger} } diff --git a/driver/redis/redis_driver.go b/driver/redis/redis_driver.go deleted file mode 100644 index 2c88a9b..0000000 --- a/driver/redis/redis_driver.go +++ /dev/null @@ -1,107 +0,0 @@ -package redis - -import ( - "context" - "fmt" - "log" - "time" - - "github.com/go-redis/redis/v8" - "github.com/libi/dcron/dlog" - "github.com/libi/dcron/driver" -) - -// RedisDriver is redisDriver -type RedisDriver struct { - client *redis.Client - timeout time.Duration - Key string - logger dlog.Logger -} - -// NewDriver return a redis driver -func NewDriver(opts *redis.Options) (*RedisDriver, error) { - return &RedisDriver{ - client: redis.NewClient(opts), - logger: &dlog.StdLogger{ - Log: log.Default(), - }, - }, nil -} - -// Ping is check redis valid -func (rd *RedisDriver) Ping() error { - reply, err := rd.client.Ping(context.Background()).Result() - if err != nil { - return err - } - if reply != "PONG" { - return fmt.Errorf("Ping received is error, %s", string(reply)) - } - return err -} - -//SetTimeout set redis timeout -func (rd *RedisDriver) 
SetTimeout(timeout time.Duration) { - rd.timeout = timeout -} - -//SetHeartBeat set heatbeat -func (rd *RedisDriver) SetHeartBeat(nodeID string) { - go rd.heartBeat(nodeID) -} -func (rd *RedisDriver) heartBeat(nodeID string) { - - //每间隔timeout/2设置一次key的超时时间为timeout - key := nodeID - tickers := time.NewTicker(rd.timeout / 2) - for range tickers.C { - keyExist, err := rd.client.Expire(context.Background(), key, rd.timeout).Result() - if err != nil { - rd.logger.Errorf("redis expire error %+v", err) - continue - } - if !keyExist { - if err := rd.registerServiceNode(nodeID); err != nil { - rd.logger.Errorf("register service node error %+v", err) - } - } - } -} - -func (rd *RedisDriver) SetLogger(log dlog.Logger) { - rd.logger = log -} - -//GetServiceNodeList get a serveice node list -func (rd *RedisDriver) GetServiceNodeList(serviceName string) ([]string, error) { - mathStr := fmt.Sprintf("%s*", driver.GetKeyPre(serviceName)) - return rd.scan(mathStr) -} - -//RegisterServiceNode register a service node -func (rd *RedisDriver) RegisterServiceNode(serviceName string) (nodeID string, err error) { - nodeID = driver.GetNodeId(serviceName) - if err := rd.registerServiceNode(nodeID); err != nil { - return "", err - } - return nodeID, nil -} - -func (rd *RedisDriver) registerServiceNode(nodeID string) error { - return rd.client.SetEX(context.Background(), nodeID, nodeID, rd.timeout).Err() -} - -func (rd *RedisDriver) scan(matchStr string) ([]string, error) { - ret := make([]string, 0) - ctx := context.Background() - iter := rd.client.Scan(ctx, 0, matchStr, -1).Iterator() - for iter.Next(ctx) { - err := iter.Err() - if err != nil { - return nil, err - } - ret = append(ret, iter.Val()) - } - return ret, nil -} diff --git a/driver/redis/redis_driver_test.go b/driver/redis/redis_driver_test.go deleted file mode 100644 index 757d675..0000000 --- a/driver/redis/redis_driver_test.go +++ /dev/null @@ -1,26 +0,0 @@ -package redis - -import ( - "testing" - - 
"github.com/go-redis/redis/v8" -) - -func TestRedisDriver_Scan(t *testing.T) { - rd, err := NewDriver(&redis.Options{ - Addr: "127.0.0.1:6379", - }) - if err != nil { - return - } - testStr := []string{ - "*", "-----", "", "!@#$%^", "1", "false", - } - for _, str := range testStr { - ret, err := rd.scan(str) - if err != nil { - t.Error(err) - } - t.Log(ret) - } -} diff --git a/driver/redis_cluster/redis_cluster.go b/driver/redis_cluster/redis_cluster.go deleted file mode 100644 index 00bbcaf..0000000 --- a/driver/redis_cluster/redis_cluster.go +++ /dev/null @@ -1,156 +0,0 @@ -package redis_cluster - -import ( - "context" - "crypto/tls" - "fmt" - "log" - "sync" - "time" - - "github.com/go-redis/redis/v8" - "github.com/libi/dcron/dlog" - "github.com/libi/dcron/driver" -) - -// Conf is redis cluster client config -type Conf struct { - Proto string - - // first use addr - Addrs []string - Password string - - MaxRedirects int - ReadOnly bool - - TLSConfig *tls.Config -} - -// RedisClusterDriver is -type RedisClusterDriver struct { - conf *Conf - redisClient *redis.ClusterClient - timeout time.Duration - Key string - ctx context.Context - logger dlog.Logger -} - -// NewDriver return a redis driver -func NewDriver(conf *Conf) (*RedisClusterDriver, error) { - opts := &redis.ClusterOptions{ - Addrs: conf.Addrs, - Password: conf.Password, - ReadOnly: conf.ReadOnly, - } - if conf.MaxRedirects > 0 { - opts.MaxRedirects = conf.MaxRedirects - } - if conf.TLSConfig != nil { - opts.TLSConfig = conf.TLSConfig - } - redisClient := redis.NewClusterClient(opts) - return &RedisClusterDriver{ - conf: conf, - redisClient: redisClient, - ctx: context.TODO(), - logger: &dlog.StdLogger{ - Log: log.Default(), - }, - }, nil -} - -// Ping to check redis cluster is valid or not -func (rd *RedisClusterDriver) Ping() error { - if err := rd.redisClient.Ping(rd.ctx).Err(); err != nil { - return err - } - return nil -} - -//SetTimeout set redis key expiration timeout -func (rd *RedisClusterDriver) 
SetTimeout(timeout time.Duration) { - rd.timeout = timeout -} - -//SetHeartBeat set heartbeat -func (rd *RedisClusterDriver) SetHeartBeat(nodeID string) { - go rd.heartBeat(nodeID) -} -func (rd *RedisClusterDriver) heartBeat(nodeID string) { - //每间隔timeout/2设置一次key的超时时间为timeout - key := nodeID - tickers := time.NewTicker(rd.timeout / 2) - for range tickers.C { - if err := rd.redisClient.Expire(rd.ctx, key, rd.timeout).Err(); err != nil { - rd.logger.Errorf("redis expire error %+v", err) - continue - } - } -} - -//GetServiceNodeList get a service node list on redis cluster -func (rd *RedisClusterDriver) GetServiceNodeList(serviceName string) ([]string, error) { - mathStr := fmt.Sprintf("%s*", driver.GetKeyPre(serviceName)) - return rd.scan(mathStr) -} - -//RegisterServiceNode register a service node -func (rd *RedisClusterDriver) RegisterServiceNode(serviceName string) (nodeID string, err error) { - nodeID = driver.GetNodeId(serviceName) - key := driver.GetKeyPre(serviceName) + nodeID - if err := rd.redisClient.Set(rd.ctx, key, nodeID, rd.timeout).Err(); err != nil { - return "", err - } - return key, nil -} - -func (rd *RedisClusterDriver) SetLogger(log dlog.Logger) { - rd.logger = log -} - -/** -集群模式下,scan命令只能在单机上执行,因此需要遍历master节点进行合并 -*/ -func (rd *RedisClusterDriver) scan(matchStr string) ([]string, error) { - l := newSyncList() - // scan不能直接执行,只能在每个master节点上上逐个执行再合并 - if err := rd.redisClient.ForEachMaster(rd.ctx, func(ctx context.Context, master *redis.Client) error { - iter := master.Scan(ctx, 0, matchStr, -1).Iterator() - for iter.Next(rd.ctx) { - err := iter.Err() - if err != nil { - return err - } - l.Append(iter.Val()) - } - return nil - }); err != nil { - return l.Values(), err - } - return l.Values(), nil -} - -type syncList struct { - sync.RWMutex - arr []string -} - -func newSyncList() *syncList { - l := new(syncList) - l.arr = make([]string, 0) - return l -} - -func (l *syncList) Append(val string) { - l.Lock() - defer l.Unlock() - l.arr = 
append(l.arr, val) -} - -func (l *syncList) Values() []string { - l.RLock() - defer l.RUnlock() - return l.arr -} diff --git a/driver/redis_cluster/redis_cluster_test.go b/driver/redis_cluster/redis_cluster_test.go deleted file mode 100644 index ca2a94b..0000000 --- a/driver/redis_cluster/redis_cluster_test.go +++ /dev/null @@ -1,22 +0,0 @@ -package redis_cluster - -import ( - "testing" - - "github.com/libi/dcron/driver" -) - -func TestClusterScan(t *testing.T) { - rd, err := NewDriver(&Conf{ - Addrs: []string{"127.0.0.1:6379"}, - }) - if err != nil { - return - } - matchStr := driver.GetKeyPre("service") - ret, err := rd.scan(matchStr) - if err != nil { - t.Log(err) - } - t.Log(ret) -} diff --git a/driver/redisdriver.go b/driver/redisdriver.go new file mode 100644 index 0000000..6ad371b --- /dev/null +++ b/driver/redisdriver.go @@ -0,0 +1,145 @@ +package driver + +import ( + "context" + "errors" + "fmt" + "log" + "sync" + "time" + + "github.com/go-redis/redis/v8" + "github.com/libi/dcron/dlog" +) + +const ( + redisDefaultTimeout = 5 * time.Second +) + +type RedisDriver struct { + c *redis.Client + serviceName string + nodeID string + timeout time.Duration + logger dlog.Logger + started bool + + // this context is used to define + // the life time of this driver. 
+ runtimeCtx context.Context + runtimeCancel context.CancelFunc + + sync.Mutex +} + +func newRedisDriver(redisClient *redis.Client) *RedisDriver { + rd := &RedisDriver{ + c: redisClient, + logger: &dlog.StdLogger{ + Log: log.Default(), + }, + timeout: redisDefaultTimeout, + } + rd.started = false + return rd +} + +func (rd *RedisDriver) Init(serviceName string, opts ...Option) { + rd.serviceName = serviceName + rd.nodeID = GetNodeId(rd.serviceName) + + for _, opt := range opts { + rd.withOption(opt) + } +} + +func (rd *RedisDriver) NodeID() string { + return rd.nodeID +} + +func (rd *RedisDriver) Start(ctx context.Context) (err error) { + rd.Lock() + defer rd.Unlock() + if rd.started { + err = errors.New("this driver is started") + return + } + rd.runtimeCtx, rd.runtimeCancel = context.WithCancel(context.TODO()) + rd.started = true + // register + err = rd.registerServiceNode() + if err != nil { + rd.logger.Errorf("register service error=%v", err) + return + } + // heartbeat timer + go rd.heartBeat() + return +} + +func (rd *RedisDriver) Stop(ctx context.Context) (err error) { + rd.Lock() + defer rd.Unlock() + rd.runtimeCancel() + rd.started = false + return +} + +func (rd *RedisDriver) GetNodes(ctx context.Context) (nodes []string, err error) { + mathStr := fmt.Sprintf("%s*", GetKeyPre(rd.serviceName)) + return rd.scan(ctx, mathStr) +} + +// private function + +func (rd *RedisDriver) heartBeat() { + tick := time.NewTicker(rd.timeout / 2) + for { + select { + case <-tick.C: + { + if err := rd.registerServiceNode(); err != nil { + rd.logger.Errorf("register service node error %+v", err) + } + } + case <-rd.runtimeCtx.Done(): + { + if err := rd.c.Del(context.Background(), rd.nodeID, rd.nodeID).Err(); err != nil { + rd.logger.Errorf("unregister service node error %+v", err) + } + return + } + } + } +} + +func (rd *RedisDriver) registerServiceNode() error { + return rd.c.SetEX(context.Background(), rd.nodeID, rd.nodeID, rd.timeout).Err() +} + +func (rd *RedisDriver) 
scan(ctx context.Context, matchStr string) ([]string, error) { + ret := make([]string, 0) + iter := rd.c.Scan(ctx, 0, matchStr, -1).Iterator() + for iter.Next(ctx) { + err := iter.Err() + if err != nil { + return nil, err + } + ret = append(ret, iter.Val()) + } + return ret, nil +} + +func (rd *RedisDriver) withOption(opt Option) (err error) { + switch opt.Type() { + case OptionTypeTimeout: + { + rd.timeout = opt.(TimeoutOption).timeout + } + case OptionTypeLogger: + { + rd.logger = opt.(LoggerOption).logger + } + } + return +} diff --git a/driver/redisdriver_test.go b/driver/redisdriver_test.go new file mode 100644 index 0000000..2761f0b --- /dev/null +++ b/driver/redisdriver_test.go @@ -0,0 +1,92 @@ +package driver_test + +import ( + "context" + "log" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/go-redis/redis/v8" + "github.com/libi/dcron/dlog" + "github.com/libi/dcron/driver" + "github.com/stretchr/testify/require" +) + +func testFuncNewRedisDriver(addr string) driver.DriverV2 { + log.Println("redis=", addr) + redisCli := redis.NewClient(&redis.Options{ + Addr: addr, + }) + return driver.NewRedisDriver(redisCli) +} + +func TestRedisDriver_GetNodes(t *testing.T) { + rds := miniredis.RunT(t) + drvs := make([]driver.DriverV2, 0) + N := 10 + for i := 0; i < N; i++ { + drv := testFuncNewRedisDriver(rds.Addr()) + drv.Init( + t.Name(), + driver.NewTimeoutOption(5*time.Second), + driver.NewLoggerOption(dlog.NewLoggerForTest(t))) + err := drv.Start(context.Background()) + require.Nil(t, err) + drvs = append(drvs, drv) + } + + for _, v := range drvs { + nodes, err := v.GetNodes(context.Background()) + require.Nil(t, err) + require.Equal(t, N, len(nodes)) + } + + for _, v := range drvs { + v.Stop(context.Background()) + } +} + +func TestRedisDriver_Stop(t *testing.T) { + var err error + var nodes []string + rds := miniredis.RunT(t) + drv1 := testFuncNewRedisDriver(rds.Addr()) + drv1.Init(t.Name(), + driver.NewTimeoutOption(5*time.Second), + 
driver.NewLoggerOption(dlog.NewLoggerForTest(t))) + + drv2 := testFuncNewRedisDriver(rds.Addr()) + drv2.Init(t.Name(), + driver.NewTimeoutOption(5*time.Second), + driver.NewLoggerOption(dlog.NewLoggerForTest(t))) + err = drv2.Start(context.Background()) + require.Nil(t, err) + + err = drv1.Start(context.Background()) + require.Nil(t, err) + + nodes, err = drv1.GetNodes(context.Background()) + require.Nil(t, err) + require.Len(t, nodes, 2) + + nodes, err = drv2.GetNodes(context.Background()) + require.Nil(t, err) + require.Len(t, nodes, 2) + + drv1.Stop(context.Background()) + + <-time.After(5 * time.Second) + nodes, err = drv2.GetNodes(context.Background()) + require.Nil(t, err) + require.Len(t, nodes, 1) + + err = drv1.Start(context.Background()) + require.Nil(t, err) + <-time.After(5 * time.Second) + nodes, err = drv2.GetNodes(context.Background()) + require.Nil(t, err) + require.Len(t, nodes, 2) + + drv2.Stop(context.Background()) +} diff --git a/driver/rediszsetdriver.go b/driver/rediszsetdriver.go new file mode 100644 index 0000000..968e695 --- /dev/null +++ b/driver/rediszsetdriver.go @@ -0,0 +1,140 @@ +package driver + +import ( + "context" + "errors" + "fmt" + "log" + "sync" + "time" + + "github.com/go-redis/redis/v8" + "github.com/libi/dcron/dlog" +) + +type RedisZSetDriver struct { + c *redis.Client + serviceName string + nodeID string + timeout time.Duration + logger dlog.Logger + started bool + + // this context is used to define + // the life time of this driver. 
+ runtimeCtx context.Context + runtimeCancel context.CancelFunc + + sync.Mutex +} + +func newRedisZSetDriver(redisClient *redis.Client) *RedisZSetDriver { + rd := &RedisZSetDriver{ + c: redisClient, + logger: &dlog.StdLogger{ + Log: log.Default(), + }, + timeout: redisDefaultTimeout, + } + rd.started = false + return rd +} + +func (rd *RedisZSetDriver) Init(serviceName string, opts ...Option) { + rd.serviceName = serviceName + rd.nodeID = GetNodeId(serviceName) + for _, opt := range opts { + rd.withOption(opt) + } +} + +func (rd *RedisZSetDriver) NodeID() string { + return rd.nodeID +} + +func (rd *RedisZSetDriver) GetNodes(ctx context.Context) (nodes []string, err error) { + rd.Lock() + defer rd.Unlock() + sliceCmd := rd.c.ZRangeByScore(ctx, GetKeyPre(rd.serviceName), &redis.ZRangeBy{ + Min: fmt.Sprintf("%d", TimePre(time.Now(), rd.timeout)), + Max: "+inf", + }) + if err = sliceCmd.Err(); err != nil { + return nil, err + } else { + nodes = make([]string, len(sliceCmd.Val())) + copy(nodes, sliceCmd.Val()) + } + rd.logger.Infof("nodes=%v", nodes) + return +} +func (rd *RedisZSetDriver) Start(ctx context.Context) (err error) { + rd.Lock() + defer rd.Unlock() + if rd.started { + err = errors.New("this driver is started") + return + } + rd.runtimeCtx, rd.runtimeCancel = context.WithCancel(context.TODO()) + rd.started = true + // register + err = rd.registerServiceNode() + if err != nil { + rd.logger.Errorf("register service error=%v", err) + return + } + // heartbeat timer + go rd.heartBeat() + return +} +func (rd *RedisZSetDriver) Stop(ctx context.Context) (err error) { + rd.Lock() + defer rd.Unlock() + rd.runtimeCancel() + rd.started = false + return +} + +func (rd *RedisZSetDriver) withOption(opt Option) (err error) { + switch opt.Type() { + case OptionTypeTimeout: + { + rd.timeout = opt.(TimeoutOption).timeout + } + case OptionTypeLogger: + { + rd.logger = opt.(LoggerOption).logger + } + } + return +} + +// private function + +func (rd *RedisZSetDriver) heartBeat() 
{ + tick := time.NewTicker(rd.timeout / 2) + for { + select { + case <-tick.C: + { + if err := rd.registerServiceNode(); err != nil { + rd.logger.Errorf("register service node error %+v", err) + } + } + case <-rd.runtimeCtx.Done(): + { + if err := rd.c.Del(context.Background(), rd.nodeID, rd.nodeID).Err(); err != nil { + rd.logger.Errorf("unregister service node error %+v", err) + } + return + } + } + } +} + +func (rd *RedisZSetDriver) registerServiceNode() error { + return rd.c.ZAdd(context.Background(), GetKeyPre(rd.serviceName), &redis.Z{ + Score: float64(time.Now().Unix()), + Member: rd.nodeID, + }).Err() +} diff --git a/driver/rediszsetdriver_test.go b/driver/rediszsetdriver_test.go new file mode 100644 index 0000000..12c8079 --- /dev/null +++ b/driver/rediszsetdriver_test.go @@ -0,0 +1,90 @@ +package driver_test + +import ( + "context" + "testing" + "time" + + "github.com/alicebob/miniredis/v2" + "github.com/go-redis/redis/v8" + "github.com/libi/dcron/dlog" + "github.com/libi/dcron/driver" + "github.com/stretchr/testify/require" +) + +func testFuncNewRedisZSetDriver(addr string) driver.DriverV2 { + redisCli := redis.NewClient(&redis.Options{ + Addr: addr, + }) + return driver.NewRedisZSetDriver(redisCli) +} + +func TestRedisZSetDriver_GetNodes(t *testing.T) { + rds := miniredis.RunT(t) + drvs := make([]driver.DriverV2, 0) + N := 10 + for i := 0; i < N; i++ { + drv := testFuncNewRedisZSetDriver(rds.Addr()) + drv.Init( + t.Name(), + driver.NewTimeoutOption(5*time.Second), + driver.NewLoggerOption(dlog.NewLoggerForTest(t))) + err := drv.Start(context.Background()) + require.Nil(t, err) + drvs = append(drvs, drv) + } + + for _, v := range drvs { + nodes, err := v.GetNodes(context.Background()) + require.Nil(t, err) + require.Equal(t, N, len(nodes)) + } + + for _, v := range drvs { + v.Stop(context.Background()) + } +} + +func TestRedisZSetDriver_Stop(t *testing.T) { + var err error + var nodes []string + rds := miniredis.RunT(t) + drv1 := 
testFuncNewRedisZSetDriver(rds.Addr()) + drv1.Init(t.Name(), + driver.NewTimeoutOption(5*time.Second), + driver.NewLoggerOption(dlog.NewLoggerForTest(t))) + + drv2 := testFuncNewRedisZSetDriver(rds.Addr()) + drv2.Init(t.Name(), + driver.NewTimeoutOption(5*time.Second), + driver.NewLoggerOption(dlog.NewLoggerForTest(t))) + err = drv2.Start(context.Background()) + require.Nil(t, err) + + err = drv1.Start(context.Background()) + require.Nil(t, err) + + nodes, err = drv1.GetNodes(context.Background()) + require.Nil(t, err) + require.Len(t, nodes, 2) + + nodes, err = drv2.GetNodes(context.Background()) + require.Nil(t, err) + require.Len(t, nodes, 2) + + drv1.Stop(context.Background()) + + <-time.After(6 * time.Second) + nodes, err = drv2.GetNodes(context.Background()) + require.Nil(t, err) + require.Len(t, nodes, 1) + + err = drv1.Start(context.Background()) + require.Nil(t, err) + <-time.After(5 * time.Second) + nodes, err = drv2.GetNodes(context.Background()) + require.Nil(t, err) + require.Len(t, nodes, 2) + + drv2.Stop(context.Background()) +} diff --git a/driver/util.go b/driver/util.go index f3dfc10..dfe20e9 100644 --- a/driver/util.go +++ b/driver/util.go @@ -1,6 +1,10 @@ package driver -import "github.com/google/uuid" +import ( + "time" + + "github.com/google/uuid" +) // GlobalKeyPrefix is global redis key preifx const GlobalKeyPrefix = "distributed-cron:" @@ -20,3 +24,7 @@ func GetStableJobStore(serviceName string) string { func GetStableJobStoreTxKey(serviceName string) string { return GetKeyPre(serviceName) + "TX:stable-jobs" } + +func TimePre(t time.Time, preDuration time.Duration) int64 { + return t.Add(-preDuration).Unix() +} diff --git a/examples/example/README.md b/examples/example/README.md index 6346bfc..f4cf566 100644 --- a/examples/example/README.md +++ b/examples/example/README.md @@ -12,11 +12,24 @@ go build -o bin/example example.go ```bash # ./run.sh $number_of_process $number_of_cronjob # in linux -./run.sh 5 3 +./run.sh 5 10 +``` +## run 1 
instance +```bash +# in linux +# ./run-instance.sh $sub_id_for_this_process $number_of_cronjob +./run-instance.sh 6 10 ``` -## stop +## stop all ```bash # in linux ./killexamples.sh +``` + +## stop 1 instance +```bash +# in linux +# ./kill-instance.sh $sub_id +./kill-instance.sh 2 ``` \ No newline at end of file diff --git a/examples/example/example.go b/examples/example/example.go index 7e7fd72..f7a8af3 100644 --- a/examples/example/example.go +++ b/examples/example/example.go @@ -1,7 +1,6 @@ package main import ( - "errors" "flag" "fmt" "log" @@ -14,9 +13,6 @@ import ( "github.com/libi/dcron" "github.com/libi/dcron/dlog" "github.com/libi/dcron/driver" - etcdDriver "github.com/libi/dcron/driver/etcd" - redisDriver "github.com/libi/dcron/driver/redis" - clientv3 "go.etcd.io/etcd/client/v3" ) const ( @@ -26,7 +22,6 @@ const ( var ( addr = flag.String("addr", "127.0.0.1:6379", "the addr of driver service") - driverType = flag.String("driver_type", "redis", "the driver type [redis/etcd]") serverName = flag.String("server_name", "server", "the server name of dcron in this process") subId = flag.String("sub_id", "1", "this process sub id in this server") jobNumber = flag.Int("jobnumber", 3, "there number of cron job") @@ -60,26 +55,14 @@ func (wj *WriteJob) Run() { } } -func getTheDriver() (driver.Driver, error) { - - if *driverType == DriverType_REDIS { - return redisDriver.NewDriver(&redis.Options{ - Addr: *addr, - }) - } else if *driverType == DriverType_ETCD { - return etcdDriver.NewEtcdDriver(&clientv3.Config{ - Endpoints: []string{*addr}, - }) - } - return nil, errors.New("driverType not suit") -} - func main() { flag.Parse() - driver, err := getTheDriver() - if err != nil { - panic(err) - } + var err error + + redisCli := redis.NewClient(&redis.Options{ + Addr: *addr, + }) + driver := driver.NewRedisDriver(redisCli) logger := &dlog.StdLogger{ Log: log.New(os.Stdout, "["+*subId+"]", log.LstdFlags), } diff --git a/examples/example/kill-instance.sh 
b/examples/example/kill-instance.sh new file mode 100755 index 0000000..8afd672 --- /dev/null +++ b/examples/example/kill-instance.sh @@ -0,0 +1,6 @@ +#!/bin/bash +kill_instance() { + ps -ef | grep example | grep "sub_id $1" | grep -v grep | awk '{print $2}' | xargs kill -9 +} + +kill_instance $1 \ No newline at end of file diff --git a/examples/example/run-instance.sh b/examples/example/run-instance.sh new file mode 100755 index 0000000..564a3d5 --- /dev/null +++ b/examples/example/run-instance.sh @@ -0,0 +1,5 @@ +start_example() { + nohup ./bin/example -sub_id $1 -jobnumber $2 & +} + +start_example $1 $2 \ No newline at end of file diff --git a/examples/stablejob/stablejob.go b/examples/stablejob/stablejob.go index 6ef508a..d13f916 100644 --- a/examples/stablejob/stablejob.go +++ b/examples/stablejob/stablejob.go @@ -9,7 +9,7 @@ import ( "github.com/go-redis/redis/v8" "github.com/libi/dcron" "github.com/libi/dcron/dlog" - redisDriver "github.com/libi/dcron/driver/redis" + "github.com/libi/dcron/driver" examplesCommon "github.com/libi/dcron/examples/common" ) @@ -45,10 +45,8 @@ func main() { Password: IEnv.RedisPassword, } - drv, err := redisDriver.NewDriver(redisOpts) - if err != nil { - logger.Fatal(err) - } + redisCli := redis.NewClient(redisOpts) + drv := driver.NewRedisDriver(redisCli) dcronInstance := dcron.NewDcronWithOption(IEnv.ServerName, drv, dcron.WithLogger(&dlog.StdLogger{ Log: logger, diff --git a/go.mod b/go.mod index ddc0fa0..2f34839 100644 --- a/go.mod +++ b/go.mod @@ -3,10 +3,11 @@ module github.com/libi/dcron go 1.16 require ( + github.com/alicebob/miniredis/v2 v2.30.1 github.com/go-redis/redis/v8 v8.11.5 github.com/google/uuid v1.1.2 github.com/robfig/cron/v3 v3.0.1 - github.com/stretchr/testify v1.7.0 + github.com/stretchr/testify v1.8.1 go.etcd.io/etcd/api/v3 v3.5.4 go.etcd.io/etcd/client/v3 v3.5.4 go.etcd.io/etcd/tests/v3 v3.5.4 diff --git a/go.sum b/go.sum index 27b808b..b35da92 100644 --- a/go.sum +++ b/go.sum @@ -21,6 +21,10 @@ 
github.com/alecthomas/template v0.0.0-20190718012654-fb15b899a751/go.mod h1:LOuy github.com/alecthomas/units v0.0.0-20151022065526-2efee857e7cf/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190717042225-c3de453c63f4/go.mod h1:ybxpYRFXyAe+OPACYpWeL0wqObRcbAqCMya13uyzqw0= github.com/alecthomas/units v0.0.0-20190924025748-f65c72e2690d/go.mod h1:rBZYJk541a8SKzHPHnH3zbiI+7dagKZ0cgpgrD7Fyho= +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a h1:HbKu58rmZpUGpz5+4FfNmIU+FmZg2P3Xaj2v2bfNWmk= +github.com/alicebob/gopher-json v0.0.0-20200520072559-a9ecdc9d1d3a/go.mod h1:SGnFV6hVsYE877CKEZ6tDNTjaSXYUk6QqoIK6PrAtcc= +github.com/alicebob/miniredis/v2 v2.30.1 h1:HM1rlQjq1bm9yQcsawJqSZBJ9AYgxvjkMsNtddh90+g= +github.com/alicebob/miniredis/v2 v2.30.1/go.mod h1:b25qWj4fCEsBeAAR2mlb0ufImGC6uH3VlUfb/HS5zKg= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= @@ -315,12 +319,17 @@ github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An github.com/spf13/viper v1.7.0/go.mod h1:8WkrPz2fc9jxqZNCJI/76HCieCp4Q8HaLFoCha5qpdg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/subosito/gotenv v1.2.0/go.mod h1:N0PQaV/YGNqwC0u51sEeR/aUtSLEXKX9iv69rRypqCw= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20201229170055-e5319fda7802 h1:uruHq4dN7GR16kFc5fp3d1RIYzJW5onx8Ybykw2YQFA= @@ -330,6 +339,8 @@ github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +github.com/yuin/gopher-lua v1.1.0 h1:BojcDhfyDWgU2f2TOzYK/g5p2gxMrku8oupLDqlnSqE= +github.com/yuin/gopher-lua v1.1.0/go.mod h1:GBR0iDaNXjAgGg9zfCvksxSRnQx76gclCIb7kdAd1Pw= go.etcd.io/bbolt v1.3.2/go.mod h1:IbVyRI1SCnLcuJnV2u8VeU0CEYM7e686BmAb1XKL+uU= go.etcd.io/bbolt v1.3.6 h1:/ecaJf0sk1l4l6V4awd65v2C3ILy7MSj+s/x1ADCIMU= go.etcd.io/bbolt v1.3.6/go.mod h1:qXsaaIqmgQH0T+OPdb99Bf+PKfBBQVAdyD6TY9G8XM4= @@ -466,6 +477,7 @@ golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20181026203630-95b1ffbd15a5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181107165924-66b7b1311ac8/go.mod 
h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20181116152217-5ac8a444bdc5/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190204203706-41f3e6584952/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -612,8 +624,9 @@ gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b h1:h8qDotaEPuJATrMmW04NCwg7v22aHH28wwpauUhK9Oo= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190106161140-3f1c8253044a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= honnef.co/go/tools v0.0.0-20190418001031-e561f6794a2a/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= diff --git a/inodepool.go b/inodepool.go new file mode 100644 index 0000000..baa59f1 --- /dev/null +++ b/inodepool.go @@ -0,0 +1,11 @@ +package dcron + +import "context" + +type INodePool interface { + Start(ctx context.Context) error + CheckJobAvailable(jobName string) bool + Stop(ctx context.Context) error + + GetNodeID() string +} diff --git a/inodepool_test.go 
b/inodepool_test.go
new file mode 100644
index 0000000..b84874d
--- /dev/null
+++ b/inodepool_test.go
@@ -0,0 +1,189 @@
+package dcron_test
+
+import (
+	"context"
+	"strconv"
+	"testing"
+	"time"
+
+	"github.com/alicebob/miniredis/v2"
+	"github.com/go-redis/redis/v8"
+	"github.com/libi/dcron"
+	"github.com/libi/dcron/consistenthash"
+	"github.com/libi/dcron/driver"
+	"github.com/stretchr/testify/suite"
+	clientv3 "go.etcd.io/etcd/client/v3"
+	"go.etcd.io/etcd/tests/v3/integration"
+)
+
+// TestINodePoolSuite checks that every node pool of a service agrees with a
+// locally built consistent-hash ring about which node owns each job.
+type TestINodePoolSuite struct {
+	suite.Suite
+
+	rds                 *miniredis.Miniredis
+	etcdsvr             integration.LazyCluster
+	defaultHashReplicas int
+}
+
+// SetupTest resets the per-test defaults.
+func (ts *TestINodePoolSuite) SetupTest() {
+	ts.defaultHashReplicas = 10
+}
+
+// TearDownTest shuts down whichever backing servers the test started.
+func (ts *TestINodePoolSuite) TearDownTest() {
+	if ts.rds != nil {
+		ts.rds.Close()
+		ts.rds = nil
+	}
+	if ts.etcdsvr != nil {
+		ts.etcdsvr.Terminate()
+		ts.etcdsvr = nil
+	}
+}
+
+// setUpRedis starts the miniredis server used by the current test.
+func (ts *TestINodePoolSuite) setUpRedis() {
+	ts.rds = miniredis.RunT(ts.T())
+}
+
+// setUpEtcd creates the etcd LazyCluster used by the current test.
+func (ts *TestINodePoolSuite) setUpEtcd() {
+	ts.etcdsvr = integration.NewLazyCluster()
+}
+
+// declareRedisDrivers appends numberOfNodes redis clients, and one redis
+// driver per client, to the given slices.
+func (ts *TestINodePoolSuite) declareRedisDrivers(clients *[]*redis.Client, drivers *[]driver.DriverV2, numberOfNodes int) {
+	for i := 0; i < numberOfNodes; i++ {
+		// Pair the driver with the client created in this iteration rather
+		// than indexing by i, which is only right when *clients starts empty.
+		cli := redis.NewClient(&redis.Options{
+			Addr: ts.rds.Addr(),
+		})
+		*clients = append(*clients, cli)
+		*drivers = append(*drivers, driver.NewRedisDriver(cli))
+	}
+}
+
+// declareEtcdDrivers appends numberOfNodes etcd clients, and one etcd driver
+// per client, to the given slices. It fails the test if a client cannot be
+// created.
+func (ts *TestINodePoolSuite) declareEtcdDrivers(clients *[]*clientv3.Client, drivers *[]driver.DriverV2, numberOfNodes int) {
+	for i := 0; i < numberOfNodes; i++ {
+		cli, err := clientv3.New(clientv3.Config{
+			Endpoints: ts.etcdsvr.EndpointsV3(),
+		})
+		if err != nil {
+			ts.T().Fatal(err)
+		}
+		*clients = append(*clients, cli)
+		*drivers = append(*drivers, driver.NewEtcdDriver(cli))
+	}
+}
+
+// declareRedisZSetDrivers appends numberOfNodes redis clients, and one
+// redis-zset driver per client, to the given slices.
+func (ts *TestINodePoolSuite) declareRedisZSetDrivers(clients *[]*redis.Client, drivers *[]driver.DriverV2, numberOfNodes int) {
+	for i := 0; i < numberOfNodes; i++ {
+		cli := redis.NewClient(&redis.Options{
+			Addr: ts.rds.Addr(),
+		})
+		*clients = append(*clients, cli)
+		*drivers = append(*drivers, driver.NewRedisZSetDriver(cli))
+	}
+}
+
+// runCheckJobAvailable starts every pool, waits two update periods
+// (presumably so each pool can observe the others), then verifies for 10000
+// job names that a pool claims a job exactly when a reference ring built from
+// all node IDs maps the job to that pool. ServiceName is currently unused.
+func (ts *TestINodePoolSuite) runCheckJobAvailable(numberOfNodes int, ServiceName string, nodePools *[]dcron.INodePool, updateDuration time.Duration) {
+	for i := 0; i < numberOfNodes; i++ {
+		// Abort with the driver's error instead of a bare, message-less Fail.
+		ts.Require().NoError((*nodePools)[i].Start(context.Background()))
+	}
+	time.Sleep(updateDuration * 2)
+
+	ring := consistenthash.New(ts.defaultHashReplicas, nil)
+	for _, v := range *nodePools {
+		ring.Add(v.GetNodeID())
+	}
+
+	for i := 0; i < 10000; i++ {
+		jobName := strconv.Itoa(i)
+		owner := ring.Get(jobName)
+		for j := 0; j < numberOfNodes; j++ {
+			pool := (*nodePools)[j]
+			ts.Require().Equal(
+				owner == pool.GetNodeID(),
+				pool.CheckJobAvailable(jobName),
+			)
+		}
+	}
+}
+
+// TestMultiNodesRedis runs the ownership check against the redis driver.
+func (ts *TestINodePoolSuite) TestMultiNodesRedis() {
+	var clients []*redis.Client
+	var drivers []driver.DriverV2
+	var nodePools []dcron.INodePool
+
+	numberOfNodes := 5
+	serviceName := "TestMultiNodesRedis"
+	updateDuration := 2 * time.Second
+
+	ts.setUpRedis()
+	ts.declareRedisDrivers(&clients, &drivers, numberOfNodes)
+
+	for i := 0; i < numberOfNodes; i++ {
+		nodePools = append(nodePools, dcron.NewNodePool(serviceName, drivers[i], updateDuration, ts.defaultHashReplicas, nil))
+	}
+	ts.runCheckJobAvailable(numberOfNodes, serviceName, &nodePools, updateDuration)
+}
+
+// TestMultiNodesEtcd runs the ownership check against the etcd driver.
+func (ts *TestINodePoolSuite) TestMultiNodesEtcd() {
+	var clients []*clientv3.Client
+	var drivers []driver.DriverV2
+	var nodePools []dcron.INodePool
+
+	numberOfNodes := 5
+	serviceName := "TestMultiNodesEtcd"
+	updateDuration := 8 * time.Second
+
+	ts.setUpEtcd()
+	ts.declareEtcdDrivers(&clients, &drivers, numberOfNodes)
+
+	for i := 0; i < numberOfNodes; i++ {
+		nodePools = append(nodePools, dcron.NewNodePool(serviceName, drivers[i], updateDuration, ts.defaultHashReplicas, nil))
+	}
+	ts.runCheckJobAvailable(numberOfNodes, serviceName, &nodePools, updateDuration)
+}
+
+// TestMultiNodesRedisZSet runs the ownership check against the redis-zset
+// driver.
+func (ts *TestINodePoolSuite) TestMultiNodesRedisZSet() {
+	var clients []*redis.Client
+	var drivers []driver.DriverV2
+	var nodePools []dcron.INodePool
+
+	numberOfNodes := 5
+	serviceName := "TestMultiNodesRedisZSet"
+	updateDuration := 2 * time.Second
+
+	ts.setUpRedis()
+	ts.declareRedisZSetDrivers(&clients, &drivers, numberOfNodes)
+
+	for i := 0; i < numberOfNodes; i++ {
+		nodePools = append(nodePools, dcron.NewNodePool(serviceName, drivers[i], updateDuration, ts.defaultHashReplicas, nil))
+	}
+	ts.runCheckJobAvailable(numberOfNodes, serviceName, &nodePools, updateDuration)
+}
+
+// TestTestINodePoolSuite is the go test entry point for the suite.
+func TestTestINodePoolSuite(t *testing.T) {
+	s := new(TestINodePoolSuite)
+	suite.Run(t, s)
+}
diff --git a/node_pool.go b/node_pool.go
deleted file mode 100644
index 6667dca..0000000
--- a/node_pool.go
+++ /dev/null
@@ -1,101 +0,0 @@
-package dcron
-
-import (
-	"sync"
-	"sync/atomic"
-	"time"
-
-	"github.com/libi/dcron/consistenthash"
-	"github.com/libi/dcron/driver"
-)
-
-// NodePool is a node pool
-type NodePool struct {
-	serviceName string
-	NodeID      string
-
-	rwMut sync.RWMutex
-	nodes *consistenthash.Map
-
-	Driver         driver.Driver
-	hashReplicas   int
-	hashFn         consistenthash.Hash
-	updateDuration time.Duration
-
-	dcron *Dcron
-}
-
-func newNodePool(serverName string, driver driver.Driver, dcron *Dcron, updateDuration time.Duration, hashReplicas int) (*NodePool, error) {
-
-	err := driver.Ping()
-	if err != nil {
-		return nil, err
-	}
-
-	nodePool := &NodePool{
-		Driver:         driver,
-		serviceName:    serverName,
-		dcron:          dcron,
-		hashReplicas:   hashReplicas,
-		updateDuration: updateDuration,
-	}
-	return nodePool, nil
-}
-
-// StartPool Start Service Watch Pool
-func (np *NodePool) StartPool() error {
-	var err error
-	np.Driver.SetTimeout(np.updateDuration)
-	np.NodeID, err = np.Driver.RegisterServiceNode(np.serviceName)
-	if err != nil {
-		return err
-	}
-	np.Driver.SetHeartBeat(np.NodeID)
-
-	err = np.updatePool()
-	if err != nil {
-		return err
-	}
-
-	go np.tickerUpdatePool()
-	return nil
-}
-
-func (np *NodePool) updatePool() error {
-	nodes, err := np.Driver.GetServiceNodeList(np.serviceName)
-	if err != nil {
-		return err
-	}
-
-	np.rwMut.Lock()
-	defer np.rwMut.Unlock()
-	np.nodes = consistenthash.New(np.hashReplicas, np.hashFn)
-	for _, node := range nodes {
-		np.nodes.Add(node)
-	}
-	return nil
-}
-func (np *NodePool) tickerUpdatePool() {
-	tickers := time.NewTicker(np.updateDuration)
-	for range tickers.C {
-		if atomic.LoadInt32(&np.dcron.running) == dcronRunning {
-			err := np.updatePool()
-			if err != nil {
-				np.dcron.logger.Infof("update node pool error %+v", err)
-			}
-		} else {
-			tickers.Stop()
-			return
-		}
-	}
-}
-
-// PickNodeByJobName : 使用一致性hash算法根据任务名获取一个执行节点
-func (np *NodePool) PickNodeByJobName(jobName string) string {
-	np.rwMut.RLock()
-	defer np.rwMut.RUnlock()
-	if np.nodes.IsEmpty() {
-		return ""
-	}
-	return np.nodes.Get(jobName)
-}
diff --git a/nodepool.go b/nodepool.go
new file mode 100644
index 0000000..8c5c288
--- /dev/null
+++ b/nodepool.go
@@ -0,0 +1,171 @@
+package dcron
+
+import (
+	"context"
+	"log"
+	"sync"
+	"time"
+
+	"github.com/libi/dcron/consistenthash"
+	"github.com/libi/dcron/dlog"
+	"github.com/libi/dcron/driver"
+)
+
+// NodePool keeps a consistent-hash ring of the nodes reported by a DriverV2
+// and uses it to decide whether a job belongs to the local node.
+type NodePool struct {
+	serviceName string
+	nodeID      string
+
+	rwMut sync.RWMutex // guards nodes and preNodes
+	nodes *consistenthash.Map
+
+	driver         driver.DriverV2
+	hashReplicas   int
+	hashFn         consistenthash.Hash
+	updateDuration time.Duration
+
+	logger   dlog.Logger
+	stopChan chan int
+	preNodes []string // node list the ring was last built from; compared in order
+}
+
+// NewNodePool creates a NodePool for serviceName on top of drv and
+// initialises the driver with updateDuration as its timeout and with the
+// pool's logger. A nil logger selects a logger backed by log.Default().
+// The pool does nothing until Start is called.
+func NewNodePool(serviceName string, drv driver.DriverV2, updateDuration time.Duration, hashReplicas int, logger dlog.Logger) INodePool {
+	if logger == nil {
+		logger = &dlog.StdLogger{
+			Log: log.Default(),
+		}
+	}
+	np := &NodePool{
+		serviceName:    serviceName,
+		driver:         drv,
+		hashReplicas:   hashReplicas,
+		updateDuration: updateDuration,
+		logger:         logger,
+		stopChan:       make(chan int, 1),
+	}
+	np.driver.Init(serviceName,
+		driver.NewTimeoutOption(updateDuration),
+		driver.NewLoggerOption(np.logger))
+	return np
+}
+
+// Start starts the driver, records the node ID it assigns, builds the first
+// hash ring from the driver's node list and launches the background refresh
+// goroutine. It logs and returns the first driver error.
+func (np *NodePool) Start(ctx context.Context) error {
+	if err := np.driver.Start(ctx); err != nil {
+		np.logger.Errorf("start pool error: %v", err)
+		return err
+	}
+	np.nodeID = np.driver.NodeID()
+	nowNodes, err := np.driver.GetNodes(ctx)
+	if err != nil {
+		np.logger.Errorf("get nodes error: %v", err)
+		return err
+	}
+	np.updateHashRing(nowNodes)
+	go np.waitingForHashRing()
+	return nil
+}
+
+// CheckJobAvailable reports whether the hash ring assigns jobName to this
+// node. It returns false while the ring is missing or empty.
+func (np *NodePool) CheckJobAvailable(jobName string) bool {
+	np.rwMut.RLock()
+	defer np.rwMut.RUnlock()
+	if np.nodes == nil {
+		// No ring has been built yet (Start not called or failed); report
+		// "not mine" instead of calling methods on a nil ring.
+		np.logger.Errorf("nodeID=%s, np.nodes is nil", np.nodeID)
+		return false
+	}
+	if np.nodes.IsEmpty() {
+		return false
+	}
+	targetNode := np.nodes.Get(jobName)
+	if np.nodeID != targetNode {
+		return false
+	}
+	np.logger.Infof("job %s, running in node: %s", jobName, targetNode)
+	return true
+}
+
+// Stop signals the refresh goroutine to exit, stops the driver and clears
+// the remembered node list. It always returns nil.
+func (np *NodePool) Stop(ctx context.Context) error {
+	// stopChan holds a single pending signal; send without blocking so that
+	// a repeated Stop, or a Stop without a running goroutine, cannot hang.
+	select {
+	case np.stopChan <- 1:
+	default:
+	}
+	np.driver.Stop(ctx)
+	np.rwMut.Lock()
+	np.preNodes = make([]string, 0)
+	np.rwMut.Unlock()
+	return nil
+}
+
+// GetNodeID returns the node ID recorded by Start; it is empty before that.
+func (np *NodePool) GetNodeID() string {
+	return np.nodeID
+}
+
+// waitingForHashRing re-reads the node list from the driver once per
+// updateDuration and refreshes the ring, until stopChan is signalled.
+func (np *NodePool) waitingForHashRing() {
+	tick := time.NewTicker(np.updateDuration)
+	defer tick.Stop() // release the ticker once the loop ends
+	for {
+		select {
+		case <-tick.C:
+			nowNodes, err := np.driver.GetNodes(context.Background())
+			if err != nil {
+				np.logger.Errorf("get nodes error %v", err)
+				continue
+			}
+			np.updateHashRing(nowNodes)
+		case <-np.stopChan:
+			return
+		}
+	}
+}
+
+// updateHashRing rebuilds the ring from nodes unless a ring already exists
+// for exactly this node list.
+func (np *NodePool) updateHashRing(nodes []string) {
+	np.rwMut.Lock()
+	defer np.rwMut.Unlock()
+	// The nil check matters for an empty first result: it equals the initial
+	// empty preNodes, yet a ring must still be created.
+	if np.nodes != nil && np.equalRing(nodes) {
+		np.logger.Infof("nowNodes=%v, preNodes=%v", nodes, np.preNodes)
+		return
+	}
+	np.logger.Infof("update hashRing nodes=%+v", nodes)
+	np.preNodes = make([]string, len(nodes))
+	copy(np.preNodes, nodes)
+	np.nodes = consistenthash.New(np.hashReplicas, np.hashFn)
+	for _, v := range nodes {
+		np.nodes.Add(v)
+	}
+}
+
+// equalRing reports whether a equals preNodes element by element, in order.
+// The caller must hold rwMut.
+func (np *NodePool) equalRing(a []string) bool {
+	if len(a) != len(np.preNodes) {
+		return false
+	}
+	for i, v := range a {
+		if v != np.preNodes[i] {
+			return false
+		}
+	}
+	return true
+}