From 6f6c610ab40118b314fc2ce9dd91f60ebbd31c2b Mon Sep 17 00:00:00 2001
From: Jon Lee <techlee@qq.com>
Date: Wed, 20 Jul 2022 15:00:08 +0800
Subject: [PATCH] feat: cache support tag

---
 cache/memory_store.go     |  26 +++--
 cache/redis_store.go      | 199 +++++++++++++++++++++++++++++++++++++-
 cache/redis_store_test.go |  54 ++++++++++-
 cache/repository.go       |   5 +
 cache/store.go            |   5 +
 cache/tag_set.go          |  61 ++++++++++++
 cache/utils.go            |  13 +++
 7 files changed, 349 insertions(+), 14 deletions(-)
 create mode 100644 cache/tag_set.go
 create mode 100644 cache/utils.go

diff --git a/cache/memory_store.go b/cache/memory_store.go
index ebf8457..259d08a 100644
--- a/cache/memory_store.go
+++ b/cache/memory_store.go
@@ -15,7 +15,7 @@ type item struct {
 
 // Expired Returns true if the item has expired.
 func (item item) Expired() bool {
-	if item.Expiration == 0 {
+	if item.Expiration < 0 {
 		return false
 	}
 	return time.Now().UnixNano() > item.Expiration
@@ -65,8 +65,8 @@ func (s *MemoryStore) Get(key string, val interface{}) error {
 
 // Put set cached value with key and expire time.
 func (s *MemoryStore) Put(key string, val interface{}, timeout time.Duration) error {
-	var e int64
-	if timeout > 0 {
+	var e int64 = -1
+	if timeout >= 0 {
 		e = time.Now().Add(timeout).UnixNano()
 	}
 
@@ -78,7 +78,7 @@ func (s *MemoryStore) Put(key string, val interface{}, timeout time.Duration) er
 		Expiration: e,
 	}
 
-	if e > 0 {
+	if e >= 0 {
 		s.DeleteExpired()
 	}
 
@@ -133,6 +133,11 @@ func (s *MemoryStore) Decrement(key string, value ...int) (int, error) {
 	return by, nil
 }
 
+// Forever Store an item in the cache indefinitely.
+func (s *MemoryStore) Forever(key string, val interface{}) error {
+	return s.Put(key, val, 0)
+}
+
 // Exist check cache's existence in memory.
 func (s *MemoryStore) Exist(key string) bool {
 	s.mu.RLock()
@@ -149,8 +154,8 @@ func (s *MemoryStore) Exist(key string) bool {
 
 // Expire set value expire time.
 func (s *MemoryStore) Expire(key string, timeout time.Duration) error {
-	var e int64
-	if timeout > 0 {
+	var e int64 = -1
+	if timeout >= 0 {
 		e = time.Now().Add(timeout).UnixNano()
 	}
 
@@ -165,7 +170,7 @@ func (s *MemoryStore) Expire(key string, timeout time.Duration) error {
 	item.Expiration = e
 	s.items[s.prefix+key] = item
 
-	if e > 0 {
+	if e >= 0 {
 		s.DeleteExpired()
 	}
 
@@ -188,6 +193,11 @@ func (s *MemoryStore) Flush() error {
 	return nil
 }
 
+func (s *MemoryStore) Tags(names ...string) Store {
+	// tags not be supported
+	return s
+}
+
 func (s *MemoryStore) TTL(key string) (int64, error) {
 	s.mu.RLock()
 	defer s.mu.RUnlock()
@@ -230,7 +240,7 @@ func (s *MemoryStore) DeleteExpired() {
 
 	smallestDuration := 0 * time.Nanosecond
 	for key, item := range s.items {
-		if item.Expiration == 0 {
+		if item.Expiration < 0 {
 			continue
 		}
 		// "Inlining" of expired
diff --git a/cache/redis_store.go b/cache/redis_store.go
index 9e3d6cf..ff083d4 100644
--- a/cache/redis_store.go
+++ b/cache/redis_store.go
@@ -3,17 +3,25 @@ package cache
 import (
 	"encoding/json"
 	"fmt"
+	"strings"
 	"time"
 
 	"github.com/gomodule/redigo/redis"
 )
 
+// ReferenceKeyForever Forever reference key.
+const ReferenceKeyForever = "forever_ref"
+
+// ReferenceKeyStandard Standard reference key.
+const ReferenceKeyStandard = "standard_ref"
+
 type RedisStore struct {
 	pool   *redis.Pool // redis connection pool
+	tagSet *TagSet
 	prefix string
 }
 
-// NewStore Create a redis cache store
+// NewRedisStore Create a redis cache store
 func NewRedisStore(pool *redis.Pool, prefix string) *RedisStore {
 	s := RedisStore{}
 	return s.SetPool(pool).SetPrefix(prefix)
@@ -38,6 +46,12 @@ func (s *RedisStore) Put(key string, val interface{}, timeout time.Duration) err
 	if err != nil {
 		return err
 	}
+
+	err = s.pushStandardKeys(key)
+	if err != nil {
+		return err
+	}
+
 	c := s.pool.Get()
 	defer c.Close()
 	_, err = c.Do("SETEX", s.prefix+key, int64(timeout/time.Second), string(b))
@@ -46,6 +60,11 @@ func (s *RedisStore) Put(key string, val interface{}, timeout time.Duration) err
 
 // Increment the value of an item in the cache.
 func (s *RedisStore) Increment(key string, value ...int) (int, error) {
+	err := s.pushStandardKeys(key)
+	if err != nil {
+		return 0, err
+	}
+
 	c := s.pool.Get()
 	defer c.Close()
 
@@ -59,6 +78,11 @@ func (s *RedisStore) Increment(key string, value ...int) (int, error) {
 
 // Decrement the value of an item in the cache.
 func (s *RedisStore) Decrement(key string, value ...int) (int, error) {
+	err := s.pushStandardKeys(key)
+	if err != nil {
+		return 0, err
+	}
+
 	c := s.pool.Get()
 	defer c.Close()
 
@@ -70,6 +94,24 @@ func (s *RedisStore) Decrement(key string, value ...int) (int, error) {
 	return redis.Int(c.Do("DECRBY", s.prefix+key, by))
 }
 
+// Forever Store an item in the cache indefinitely.
+func (s *RedisStore) Forever(key string, val interface{}) error {
+	b, err := json.Marshal(val)
+	if err != nil {
+		return err
+	}
+
+	err = s.pushForeverKeys(key)
+	if err != nil {
+		return err
+	}
+
+	c := s.pool.Get()
+	defer c.Close()
+	_, err = c.Do("SET", s.prefix+key, string(b))
+	return err
+}
+
 // Exist check cache's existence in redis.
 func (s *RedisStore) Exist(key string) bool {
 	c := s.pool.Get()
@@ -100,6 +142,26 @@ func (s *RedisStore) Forget(key string) error {
 
 // Remove all items from the cache.
 func (s *RedisStore) Flush() error {
+	if s.tagSet != nil {
+		err := s.deleteForeverKeys()
+		if err != nil {
+			return err
+		}
+		err = s.deleteStandardKeys()
+		if err != nil {
+			return err
+		}
+		err = s.tagSet.Reset()
+		if err != nil {
+			return err
+		}
+		return nil
+	}
+
+	return s.flush()
+}
+
+func (s *RedisStore) flush() error {
 	c := s.pool.Get()
 	defer c.Close()
 
@@ -121,12 +183,34 @@ func (s *RedisStore) Flush() error {
 			break
 		}
 	}
-	for _, key := range keys {
-		if _, err = c.Do("DEL", key); err != nil {
-			return err
+
+	length := len(keys)
+	if length == 0 {
+		return nil
+	}
+
+	var keysChunk []interface{}
+	for i, key := range keys {
+		keysChunk = append(keysChunk, key)
+		if i == length-1 || len(keysChunk) == 1000 {
+			_, err = c.Do("DEL", keysChunk...)
+			if err != nil {
+				return err
+			}
 		}
 	}
-	return err
+
+	return nil
+}
+
+func (s *RedisStore) Tags(names ...string) Store {
+	if len(names) == 0 {
+		return s
+	}
+	ss := s.clone()
+	ss.tagSet = NewTagSet(s, names)
+
+	return ss
 }
 
 func (s *RedisStore) TTL(key string) (int64, error) {
@@ -156,3 +240,108 @@ func (s *RedisStore) SetPrefix(prefix string) *RedisStore {
 	}
 	return s
 }
+
+func (s *RedisStore) clone() *RedisStore {
+	return &RedisStore{
+		pool:   s.pool,
+		prefix: s.prefix,
+	}
+}
+
+func (s *RedisStore) pushStandardKeys(key string) error {
+	return s.pushKeys(key, ReferenceKeyStandard)
+}
+
+func (s *RedisStore) pushForeverKeys(key string) error {
+	return s.pushKeys(key, ReferenceKeyForever)
+}
+
+func (s *RedisStore) pushKeys(key, reference string) error {
+	if s.tagSet == nil {
+		return nil
+	}
+
+	namespace, err := s.tagSet.GetNamespace()
+	if err != nil {
+		return err
+	}
+
+	fullKey := s.prefix + key
+	segments := strings.Split(namespace, "|")
+
+	c := s.pool.Get()
+	defer c.Close()
+	for _, segment := range segments {
+		_, err = c.Do("SADD", s.referenceKey(segment, reference), fullKey)
+		if err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func (s *RedisStore) deleteStandardKeys() error {
+	return s.deleteKeysByReference(ReferenceKeyStandard)
+}
+
+func (s *RedisStore) deleteForeverKeys() error {
+	return s.deleteKeysByReference(ReferenceKeyForever)
+}
+
+func (s *RedisStore) deleteKeysByReference(reference string) error {
+	if s.tagSet == nil {
+		return nil
+	}
+
+	namespace, err := s.tagSet.GetNamespace()
+	if err != nil {
+		return err
+	}
+	segments := strings.Split(namespace, "|")
+	c := s.pool.Get()
+	defer c.Close()
+
+	for _, segment := range segments {
+		segment = s.referenceKey(segment, reference)
+		err = s.deleteKeys(segment)
+		if err != nil {
+			return err
+		}
+		_, err = c.Do("DEL", segment)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (s *RedisStore) deleteKeys(referenceKey string) error {
+	c := s.pool.Get()
+	defer c.Close()
+	keys, err := redis.Strings(c.Do("SMEMBERS", referenceKey))
+	if err != nil {
+		return err
+	}
+	var length = len(keys)
+	if length == 0 {
+		return nil
+	}
+
+	var keysChunk []interface{}
+	for i, key := range keys {
+		keysChunk = append(keysChunk, key)
+		if i == length-1 || len(keysChunk) == 1000 {
+			_, err = c.Do("DEL", keysChunk...)
+			if err != nil {
+				return err
+			}
+		}
+	}
+
+	return nil
+}
+
+func (s *RedisStore) referenceKey(segment, suffix string) string {
+	return s.prefix + segment + ":" + suffix
+}
diff --git a/cache/redis_store_test.go b/cache/redis_store_test.go
index c5cea15..49fa084 100644
--- a/cache/redis_store_test.go
+++ b/cache/redis_store_test.go
@@ -16,10 +16,18 @@ func GetPool() *redis.Pool {
 		Wait:        true,
 		// Other pool configuration not shown in this example.
 		Dial: func() (redis.Conn, error) {
-			c, err := redis.Dial("tcp", "127.0.0.1:6379")
+			c, err := redis.Dial("tcp", "10.0.41.242:6379")
 			if err != nil {
 				return nil, err
 			}
+			if _, err := c.Do("AUTH", "abc-123"); err != nil {
+				c.Close()
+				return nil, err
+			}
+			if _, err := c.Do("SELECT", 0); err != nil {
+				c.Close()
+				return nil, err
+			}
 			return c, nil
 		},
 	}
@@ -132,3 +140,47 @@ func TestRedisStoreFlush(t *testing.T) {
 		t.Error(errors.New("Expect false"))
 	}
 }
+
+func TestRedisStore_Tags(t *testing.T) {
+	cache := getRedisStore()
+	key := "john"
+	value := "LosAngeles"
+	err := cache.Tags("people", "artists").Put(key, value, time.Hour)
+	if err != nil {
+		t.Fatal(err)
+	}
+	var val1 string
+	cache.Get(key, &val1)
+	if value != val1 {
+		t.Errorf("%s != %s", value, val1)
+	}
+
+	var val2 string
+	cache.Tags("people").Get(key, &val2)
+	if value != val2 {
+		t.Errorf("%s != %s", value, val2)
+	}
+
+	var val3 string
+	cache.Tags("artists").Get(key, &val3)
+	if value != val3 {
+		t.Errorf("%s != %s", value, val3)
+	}
+
+	cache.Tags("people").Put("bob", "NewYork", time.Hour)
+
+	err = cache.Tags("artists").Flush()
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	err = cache.Tags("artists").Get(key, &val1)
+	if err == nil {
+		t.Fatal("err should not be nil")
+	}
+
+	cache.Tags("people").Get("bob", &val2)
+	if "NewYork" != val2 {
+		t.Errorf("%s != %s", "NewYork", val2)
+	}
+}
diff --git a/cache/repository.go b/cache/repository.go
index 30cf2fb..030639f 100644
--- a/cache/repository.go
+++ b/cache/repository.go
@@ -103,6 +103,11 @@ func (r *Repository) Clear() error {
 	return r.store.Flush()
 }
 
+// Tags Begin executing a new tags operation if the store supports it.
+func (r *Repository) Tags(names ...string) *Repository {
+	return NewRepository(r.store.Tags(names...))
+}
+
 // TTL get the ttl of the key.
 func (r *Repository) TTL(key string) (int64, error) {
 	return r.store.TTL(key)
diff --git a/cache/store.go b/cache/store.go
index f2795a8..ecc31ac 100644
--- a/cache/store.go
+++ b/cache/store.go
@@ -17,6 +17,9 @@ type Store interface {
 	// Decrement the value of an item in the cache.
 	Decrement(key string, value ...int) (int, error)
 
+	// Forever Store an item in the cache indefinitely.
+	Forever(key string, val interface{}) error
+
 	// Exist check cache's existence in redis.
 	Exist(key string) bool
 
@@ -29,6 +32,8 @@ type Store interface {
 	// Flush Remove all items from the cache.
 	Flush() error
 
+	Tags(names ...string) Store
+
 	// TTL get the ttl of the key.
 	TTL(key string) (int64, error)
 }
diff --git a/cache/tag_set.go b/cache/tag_set.go
new file mode 100644
index 0000000..6a1ae5b
--- /dev/null
+++ b/cache/tag_set.go
@@ -0,0 +1,61 @@
+package cache
+
+import (
+	"strings"
+)
+
+type TagSet struct {
+	names []string
+	store Store
+}
+
+func NewTagSet(store Store, names []string) *TagSet {
+	return &TagSet{names: names, store: store}
+}
+
+func (t *TagSet) Reset() error {
+	for _, name := range t.names {
+		_, err := t.ResetTag(name)
+		if err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+func (t *TagSet) ResetTag(name string) (string, error) {
+	id := Sha1(name)
+
+	err := t.store.Forever(t.TagKey(name), id)
+
+	return id, err
+}
+
+func (t *TagSet) GetNamespace() (string, error) {
+	var err error
+	var names = make([]string, len(t.names))
+	for i, name := range t.names {
+		name, err = t.TagId(name)
+		if err != nil {
+			return "", err
+		}
+		names[i] = name
+	}
+	return strings.Join(names, "|"), nil
+}
+
+func (t *TagSet) TagId(name string) (string, error) {
+	var id string
+	tagKey := t.TagKey(name)
+	err := t.store.Get(tagKey, &id)
+	if err != nil {
+		return t.ResetTag(name)
+	}
+
+	return id, nil
+}
+
+func (t *TagSet) TagKey(name string) string {
+	return "tag:" + name + ":key"
+}
diff --git a/cache/utils.go b/cache/utils.go
new file mode 100644
index 0000000..79f4bb6
--- /dev/null
+++ b/cache/utils.go
@@ -0,0 +1,13 @@
+package cache
+
+import (
+	"crypto/sha1"
+	"fmt"
+)
+
+// Sha1 Calculate the sha1 hash of a string
+func Sha1(str string) string {
+	h := sha1.New()
+	_, _ = h.Write([]byte(str))
+	return fmt.Sprintf("%x", h.Sum(nil))
+}