Skip to content

Commit

Permalink
improve cache consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
zhenghaoz committed Jun 5, 2021
1 parent 3adbeb8 commit cda9ce9
Show file tree
Hide file tree
Showing 22 changed files with 796 additions and 724 deletions.
49 changes: 33 additions & 16 deletions master/local_cache.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
// Copyright 2021 gorse Project Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package master

import (
Expand All @@ -8,15 +22,17 @@ import (
"path/filepath"
)

// LocalCache is local cache for the master node.
type LocalCache struct {
path string
ModelName string
ModelVersion int64
Model ranking.Model
ModelScore ranking.Score
UserIndex base.Index
path string
RankingModelName string
RankingModelVersion int64
RankingModel ranking.Model
RankingScore ranking.Score
UserIndex base.Index
}

// LoadLocalCache loads local cache from a file.
func LoadLocalCache(path string) (*LocalCache, error) {
state := &LocalCache{path: path}
// check if file exists
Expand All @@ -34,27 +50,27 @@ func LoadLocalCache(path string) (*LocalCache, error) {
}
decoder := gob.NewDecoder(f)
// 1. model name
err = decoder.Decode(&state.ModelName)
err = decoder.Decode(&state.RankingModelName)
if err != nil {
return state, err
}
// 2. model version
err = decoder.Decode(&state.ModelVersion)
err = decoder.Decode(&state.RankingModelVersion)
if err != nil {
return state, err
}
// 3. model
state.Model, err = ranking.NewModel(state.ModelName, nil)
state.RankingModel, err = ranking.NewModel(state.RankingModelName, nil)
if err != nil {
return state, err
}
err = decoder.Decode(state.Model)
err = decoder.Decode(state.RankingModel)
if err != nil {
return state, err
}
state.Model.SetParams(state.Model.GetParams())
state.RankingModel.SetParams(state.RankingModel.GetParams())
// 4. model score
err = decoder.Decode(&state.ModelScore)
err = decoder.Decode(&state.RankingScore)
if err != nil {
return state, err
}
Expand All @@ -67,6 +83,7 @@ func LoadLocalCache(path string) (*LocalCache, error) {
return state, nil
}

// WriteLocalCache writes local cache to a file.
func (c *LocalCache) WriteLocalCache() error {
// create parent folder if not exists
parent := filepath.Dir(c.path)
Expand All @@ -83,22 +100,22 @@ func (c *LocalCache) WriteLocalCache() error {
}
encoder := gob.NewEncoder(f)
// 1. model name
err = encoder.Encode(c.ModelName)
err = encoder.Encode(c.RankingModelName)
if err != nil {
return err
}
// 2. model version
err = encoder.Encode(c.ModelVersion)
err = encoder.Encode(c.RankingModelVersion)
if err != nil {
return err
}
// 3. model
err = encoder.Encode(c.Model)
err = encoder.Encode(c.RankingModel)
if err != nil {
return err
}
// 4. model score
err = encoder.Encode(c.ModelScore)
err = encoder.Encode(c.RankingScore)
if err != nil {
return err
}
Expand Down
83 changes: 39 additions & 44 deletions master/master.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,12 @@ type Master struct {
userIndexMutex sync.Mutex

// personal ranking model
prModel ranking.Model
prModelName string
prVersion int64
prScore ranking.Score
prMutex sync.Mutex
prSearcher *ranking.ModelSearcher

// factorization machine
//fmModel ctr.FactorizationMachine
//ctrVersion int64
//fmMutex sync.mutex
rankingModel ranking.Model
rankingModelName string
rankingModelVersion int64
rankingScore ranking.Score
rankingModelMutex sync.Mutex
rankingModelSearcher *ranking.ModelSearcher

localCache *LocalCache
}
Expand All @@ -78,13 +73,13 @@ func NewMaster(cfg *config.Config) *Master {
return &Master{
nodesInfo: make(map[string]*Node),
// init versions
prVersion: rand.Int63(),
rankingModelVersion: rand.Int63(),
// ctrVersion: rand.Int63(),
userIndexVersion: rand.Int63(),
// default model
prModelName: "bpr",
prModel: ranking.NewBPR(nil),
prSearcher: ranking.NewModelSearcher(cfg.Recommend.SearchEpoch, cfg.Recommend.SearchTrials),
rankingModelName: "bpr",
rankingModel: ranking.NewBPR(nil),
rankingModelSearcher: ranking.NewModelSearcher(cfg.Recommend.SearchEpoch, cfg.Recommend.SearchTrials),
RestServer: server.RestServer{
GorseConfig: cfg,
HttpHost: cfg.Master.HttpHost,
Expand All @@ -103,16 +98,16 @@ func (m *Master) Serve() {
if err != nil {
base.Logger().Error("failed to load local cache", zap.Error(err))
}
if m.localCache.Model != nil {
if m.localCache.RankingModel != nil {
base.Logger().Info("load cached model",
zap.String("model_name", m.localCache.ModelName),
zap.String("model_version", base.Hex(m.localCache.ModelVersion)),
zap.Float32("model_score", m.localCache.ModelScore.NDCG),
zap.Any("params", m.localCache.Model.GetParams()))
m.prModel = m.localCache.Model
m.prModelName = m.localCache.ModelName
m.prVersion = m.localCache.ModelVersion
m.prScore = m.localCache.ModelScore
zap.String("model_name", m.localCache.RankingModelName),
zap.String("model_version", base.Hex(m.localCache.RankingModelVersion)),
zap.Float32("model_score", m.localCache.RankingScore.NDCG),
zap.Any("params", m.localCache.RankingModel.GetParams()))
m.rankingModel = m.localCache.RankingModel
m.rankingModelName = m.localCache.RankingModelName
m.rankingModelVersion = m.localCache.RankingModelVersion
m.rankingScore = m.localCache.RankingScore
}

// create cluster meta cache
Expand All @@ -126,16 +121,16 @@ func (m *Master) Serve() {
}

// connect data database
m.DataStore, err = data.Open(m.GorseConfig.Database.DataStore)
m.DataClient, err = data.Open(m.GorseConfig.Database.DataStore)
if err != nil {
base.Logger().Fatal("failed to connect data database", zap.Error(err))
}
if err = m.DataStore.Init(); err != nil {
if err = m.DataClient.Init(); err != nil {
base.Logger().Fatal("failed to init database", zap.Error(err))
}

// connect cache database
m.CacheStore, err = cache.Open(m.GorseConfig.Database.CacheStore)
m.CacheClient, err = cache.Open(m.GorseConfig.Database.CacheStore)
if err != nil {
base.Logger().Fatal("failed to connect cache database", zap.Error(err),
zap.String("database", m.GorseConfig.Database.CacheStore))
Expand Down Expand Up @@ -172,20 +167,20 @@ func (m *Master) FitLoop() {
for {
// download dataset
base.Logger().Info("load dataset for model fit", zap.Strings("feedback_types", m.GorseConfig.Database.PositiveFeedbackType))
dataSet, items, feedbacks, err := ranking.LoadDataFromDatabase(m.DataStore, m.GorseConfig.Database.PositiveFeedbackType,
dataSet, items, feedbacks, err := ranking.LoadDataFromDatabase(m.DataClient, m.GorseConfig.Database.PositiveFeedbackType,
m.GorseConfig.Database.ItemTTL, m.GorseConfig.Database.PositiveFeedbackTTL)
if err != nil {
base.Logger().Error("failed to load database", zap.Error(err))
goto sleep
}
// save stats
if err = m.CacheStore.SetString(cache.GlobalMeta, cache.NumUsers, strconv.Itoa(dataSet.UserCount())); err != nil {
if err = m.CacheClient.SetString(cache.GlobalMeta, cache.NumUsers, strconv.Itoa(dataSet.UserCount())); err != nil {
base.Logger().Error("failed to write meta", zap.Error(err))
}
if err = m.CacheStore.SetString(cache.GlobalMeta, cache.NumItems, strconv.Itoa(dataSet.ItemCount())); err != nil {
if err = m.CacheClient.SetString(cache.GlobalMeta, cache.NumItems, strconv.Itoa(dataSet.ItemCount())); err != nil {
base.Logger().Error("failed to write meta", zap.Error(err))
}
if err = m.CacheStore.SetString(cache.GlobalMeta, cache.NumPositiveFeedback, strconv.Itoa(dataSet.Count())); err != nil {
if err = m.CacheClient.SetString(cache.GlobalMeta, cache.NumPositiveFeedback, strconv.Itoa(dataSet.Count())); err != nil {
base.Logger().Error("failed to write meta", zap.Error(err))
}
// sleep if empty
Expand All @@ -194,33 +189,33 @@ func (m *Master) FitLoop() {
goto sleep
}
// check best model
bestName, bestModel, bestScore = m.prSearcher.GetBestModel()
m.prMutex.Lock()
bestName, bestModel, bestScore = m.rankingModelSearcher.GetBestModel()
m.rankingModelMutex.Lock()
if bestName != "" &&
(bestName != m.prModelName || bestModel.GetParams().ToString() != m.prModel.GetParams().ToString()) &&
(bestScore.NDCG > m.prScore.NDCG) {
(bestName != m.rankingModelName || bestModel.GetParams().ToString() != m.rankingModel.GetParams().ToString()) &&
(bestScore.NDCG > m.rankingScore.NDCG) {
// 1. best model must have been found.
// 2. best model must be different from current model
// 3. best model must perform better than current model
m.prModel = bestModel
m.prModelName = bestName
m.rankingModel = bestModel
m.rankingModelName = bestName
base.Logger().Info("find better model",
zap.String("name", bestName),
zap.Any("params", m.prModel.GetParams()))
zap.Any("params", m.rankingModel.GetParams()))
} else if dataSet.UserCount() == lastNumUsers && dataSet.ItemCount() == lastNumItems && dataSet.Count() == lastNumFeedback {
// sleep if nothing changed
m.prMutex.Unlock()
m.rankingModelMutex.Unlock()
goto sleep
}
m.prMutex.Unlock()
m.rankingModelMutex.Unlock()
lastNumUsers, lastNumItems, lastNumFeedback = dataSet.UserCount(), dataSet.ItemCount(), dataSet.Count()
// update user index
m.userIndexMutex.Lock()
m.userIndex = dataSet.UserIndex
m.userIndexVersion++
m.userIndexMutex.Unlock()
// fit model
m.fitPRModel(dataSet, m.prModel)
m.fitRankingModel(dataSet, m.rankingModel)
// collect similar items
m.similar(items, dataSet, model.SimilarityDot)
// collect popular items
Expand All @@ -233,15 +228,15 @@ func (m *Master) FitLoop() {
}
}

// SearchLoop searches optimal recommendation model in background. It never modifies variables other than prSearcher.
// SearchLoop searches optimal recommendation model in background. It never modifies variables other than rankingModelSearcher.
func (m *Master) SearchLoop() {
defer base.CheckPanic()
lastNumUsers, lastNumItems, lastNumFeedback := 0, 0, 0
for {
var trainSet, valSet *ranking.DataSet
// download dataset
base.Logger().Info("load dataset for model search", zap.Strings("feedback_types", m.GorseConfig.Database.PositiveFeedbackType))
dataSet, _, _, err := ranking.LoadDataFromDatabase(m.DataStore, m.GorseConfig.Database.PositiveFeedbackType,
dataSet, _, _, err := ranking.LoadDataFromDatabase(m.DataClient, m.GorseConfig.Database.PositiveFeedbackType,
m.GorseConfig.Database.ItemTTL, m.GorseConfig.Database.PositiveFeedbackTTL)
if err != nil {
base.Logger().Error("failed to load database", zap.Error(err))
Expand All @@ -258,7 +253,7 @@ func (m *Master) SearchLoop() {
}
// start search
trainSet, valSet = dataSet.Split(0, 0)
err = m.prSearcher.Fit(trainSet, valSet)
err = m.rankingModelSearcher.Fit(trainSet, valSet)
if err != nil {
base.Logger().Error("failed to search model", zap.Error(err))
}
Expand Down
24 changes: 12 additions & 12 deletions master/master_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ func newMockMaster(t *testing.T) *mockMaster {
s.cacheStoreServer, err = miniredis.Run()
assert.Nil(t, err)
// open database
s.DataStore, err = data.Open("redis://" + s.dataStoreServer.Addr())
s.DataClient, err = data.Open("redis://" + s.dataStoreServer.Addr())
assert.Nil(t, err)
s.CacheStore, err = cache.Open("redis://" + s.cacheStoreServer.Addr())
s.CacheClient, err = cache.Open("redis://" + s.cacheStoreServer.Addr())
assert.Nil(t, err)
return s
}
Expand All @@ -76,21 +76,21 @@ func TestMaster_CollectLatest(t *testing.T) {
}
m.latest(items)
// check latest items
latest, err := m.CacheStore.GetScores(cache.LatestItems, "", 0, 100)
latest, err := m.CacheClient.GetScores(cache.LatestItems, "", 0, 100)
assert.Nil(t, err)
assert.Equal(t, []cache.ScoredItem{
{items[9].ItemId, float32(items[9].Timestamp.Unix())},
{items[8].ItemId, float32(items[8].Timestamp.Unix())},
{items[7].ItemId, float32(items[7].Timestamp.Unix())},
}, latest)
latest, err = m.CacheStore.GetScores(cache.LatestItems, "even", 0, 100)
latest, err = m.CacheClient.GetScores(cache.LatestItems, "even", 0, 100)
assert.Nil(t, err)
assert.Equal(t, []cache.ScoredItem{
{items[8].ItemId, float32(items[8].Timestamp.Unix())},
{items[6].ItemId, float32(items[6].Timestamp.Unix())},
{items[4].ItemId, float32(items[4].Timestamp.Unix())},
}, latest)
latest, err = m.CacheStore.GetScores(cache.LatestItems, "odd", 0, 100)
latest, err = m.CacheClient.GetScores(cache.LatestItems, "odd", 0, 100)
assert.Nil(t, err)
assert.Equal(t, []cache.ScoredItem{
{items[9].ItemId, float32(items[9].Timestamp.Unix())},
Expand Down Expand Up @@ -143,21 +143,21 @@ func TestMaster_CollectPopItem(t *testing.T) {
}
m.popItem(items, feedbacks)
// check popular items
popular, err := m.CacheStore.GetScores(cache.PopularItems, "", 0, 100)
popular, err := m.CacheClient.GetScores(cache.PopularItems, "", 0, 100)
assert.Nil(t, err)
assert.Equal(t, []cache.ScoredItem{
{ItemId: items[9].ItemId, Score: 10},
{ItemId: items[8].ItemId, Score: 9},
{ItemId: items[7].ItemId, Score: 8},
}, popular)
popular, err = m.CacheStore.GetScores(cache.PopularItems, "even", 0, 100)
popular, err = m.CacheClient.GetScores(cache.PopularItems, "even", 0, 100)
assert.Nil(t, err)
assert.Equal(t, []cache.ScoredItem{
{ItemId: items[8].ItemId, Score: 9},
{ItemId: items[6].ItemId, Score: 7},
{ItemId: items[4].ItemId, Score: 5},
}, popular)
popular, err = m.CacheStore.GetScores(cache.PopularItems, "odd", 0, 100)
popular, err = m.CacheClient.GetScores(cache.PopularItems, "odd", 0, 100)
assert.Nil(t, err)
assert.Equal(t, []cache.ScoredItem{
{ItemId: items[9].ItemId, Score: 10},
Expand Down Expand Up @@ -201,15 +201,15 @@ func TestMaster_FitCFModel(t *testing.T) {
}
}
var err error
err = m.DataStore.BatchInsertItem(items)
err = m.DataClient.BatchInsertItem(items)
assert.Nil(t, err)
err = m.DataStore.BatchInsertFeedback(feedbacks, true, true)
err = m.DataClient.BatchInsertFeedback(feedbacks, true, true)
assert.Nil(t, err)
dataset, _, _, err := ranking.LoadDataFromDatabase(m.DataStore, []string{"FeedbackType"}, 0, 0)
dataset, _, _, err := ranking.LoadDataFromDatabase(m.DataClient, []string{"FeedbackType"}, 0, 0)
assert.Nil(t, err)
// similar items (common users)
m.similar(items, dataset, model.SimilarityDot)
similar, err := m.CacheStore.GetScores(cache.SimilarItems, "9", 0, 100)
similar, err := m.CacheClient.GetScores(cache.SimilarItems, "9", 0, 100)
assert.Nil(t, err)
assert.Equal(t, []string{"8", "7", "6"}, cache.RemoveScores(similar))
}
Loading

0 comments on commit cda9ce9

Please sign in to comment.