From cda9ce9aeec499db459d428369a90e5a94ae0e23 Mon Sep 17 00:00:00 2001 From: Zhenghao Zhang Date: Mon, 31 May 2021 14:00:31 +0800 Subject: [PATCH] improve cache consistency --- master/local_cache.go | 49 ++++-- master/master.go | 83 +++++---- master/master_test.go | 24 +-- master/recommend.go | 52 +++--- master/rest.go | 51 +++--- master/rest_test.go | 90 +++++----- master/rpc.go | 89 +++++----- protocol/protocol.pb.go | 130 +++++++------- protocol/protocol.proto | 6 +- protocol/protocol_grpc.pb.go | 60 ++----- server/local_cache.go | 3 + server/rest.go | 130 +++++++------- server/rest_test.go | 87 +++++----- server/server.go | 75 ++------ server/swagger.go | 1 + storage/cache/database.go | 17 +- storage/cache/database_test.go | 36 ++-- storage/cache/no_database.go | 12 ++ storage/cache/redis.go | 86 ++++++---- worker/local_cache.go | 3 + worker/worker.go | 304 +++++++++++++++------------------ worker/worker_test.go | 132 ++++++++++++++ 22 files changed, 796 insertions(+), 724 deletions(-) diff --git a/master/local_cache.go b/master/local_cache.go index 08e22f8e7..34836e18a 100644 --- a/master/local_cache.go +++ b/master/local_cache.go @@ -1,3 +1,17 @@ +// Copyright 2021 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package master import ( @@ -8,15 +22,17 @@ import ( "path/filepath" ) +// LocalCache is local cache for the master node. type LocalCache struct { - path string - ModelName string - ModelVersion int64 - Model ranking.Model - ModelScore ranking.Score - UserIndex base.Index + path string + RankingModelName string + RankingModelVersion int64 + RankingModel ranking.Model + RankingScore ranking.Score + UserIndex base.Index } +// LoadLocalCache loads local cache from a file. func LoadLocalCache(path string) (*LocalCache, error) { state := &LocalCache{path: path} // check if file exists @@ -34,27 +50,27 @@ func LoadLocalCache(path string) (*LocalCache, error) { } decoder := gob.NewDecoder(f) // 1. model name - err = decoder.Decode(&state.ModelName) + err = decoder.Decode(&state.RankingModelName) if err != nil { return state, err } // 2. model version - err = decoder.Decode(&state.ModelVersion) + err = decoder.Decode(&state.RankingModelVersion) if err != nil { return state, err } // 3. model - state.Model, err = ranking.NewModel(state.ModelName, nil) + state.RankingModel, err = ranking.NewModel(state.RankingModelName, nil) if err != nil { return state, err } - err = decoder.Decode(state.Model) + err = decoder.Decode(state.RankingModel) if err != nil { return state, err } - state.Model.SetParams(state.Model.GetParams()) + state.RankingModel.SetParams(state.RankingModel.GetParams()) // 4. model score - err = decoder.Decode(&state.ModelScore) + err = decoder.Decode(&state.RankingScore) if err != nil { return state, err } @@ -67,6 +83,7 @@ func LoadLocalCache(path string) (*LocalCache, error) { return state, nil } +// WriteLocalCache writes local cache to a file. func (c *LocalCache) WriteLocalCache() error { // create parent folder if not exists parent := filepath.Dir(c.path) @@ -83,22 +100,22 @@ func (c *LocalCache) WriteLocalCache() error { } encoder := gob.NewEncoder(f) // 1. model name - err = encoder.Encode(c.ModelName) + err = encoder.Encode(c.RankingModelName) if err != nil { return err } // 2. model version - err = encoder.Encode(c.ModelVersion) + err = encoder.Encode(c.RankingModelVersion) if err != nil { return err } // 3. model - err = encoder.Encode(c.Model) + err = encoder.Encode(c.RankingModel) if err != nil { return err } // 4. model score - err = encoder.Encode(c.ModelScore) + err = encoder.Encode(c.RankingScore) if err != nil { return err } diff --git a/master/master.go b/master/master.go index f43e75a21..855c365c0 100644 --- a/master/master.go +++ b/master/master.go @@ -58,17 +58,12 @@ type Master struct { userIndexMutex sync.Mutex // personal ranking model - prModel ranking.Model - prModelName string - prVersion int64 - prScore ranking.Score - prMutex sync.Mutex - prSearcher *ranking.ModelSearcher - - // factorization machine - //fmModel ctr.FactorizationMachine - //ctrVersion int64 - //fmMutex sync.mutex + rankingModel ranking.Model + rankingModelName string + rankingModelVersion int64 + rankingScore ranking.Score + rankingModelMutex sync.Mutex + rankingModelSearcher *ranking.ModelSearcher localCache *LocalCache } @@ -78,13 +73,13 @@ func NewMaster(cfg *config.Config) *Master { return &Master{ nodesInfo: make(map[string]*Node), // init versions - prVersion: rand.Int63(), + rankingModelVersion: rand.Int63(), // ctrVersion: rand.Int63(), userIndexVersion: rand.Int63(), // default model - prModelName: "bpr", - prModel: ranking.NewBPR(nil), - prSearcher: ranking.NewModelSearcher(cfg.Recommend.SearchEpoch, cfg.Recommend.SearchTrials), + rankingModelName: "bpr", + rankingModel: ranking.NewBPR(nil), + rankingModelSearcher: ranking.NewModelSearcher(cfg.Recommend.SearchEpoch, cfg.Recommend.SearchTrials), RestServer: server.RestServer{ GorseConfig: cfg, HttpHost: cfg.Master.HttpHost, @@ -103,16 +98,16 @@ func (m *Master) Serve() { if err != nil { base.Logger().Error("failed to load local cache", zap.Error(err)) } - if m.localCache.Model != nil { + if m.localCache.RankingModel != nil { base.Logger().Info("load cached model", - zap.String("model_name", m.localCache.ModelName), - zap.String("model_version", base.Hex(m.localCache.ModelVersion)), - zap.Float32("model_score", m.localCache.ModelScore.NDCG), - zap.Any("params", m.localCache.Model.GetParams())) - m.prModel = m.localCache.Model - m.prModelName = m.localCache.ModelName - m.prVersion = m.localCache.ModelVersion - m.prScore = m.localCache.ModelScore + zap.String("model_name", m.localCache.RankingModelName), + zap.String("model_version", base.Hex(m.localCache.RankingModelVersion)), + zap.Float32("model_score", m.localCache.RankingScore.NDCG), + zap.Any("params", m.localCache.RankingModel.GetParams())) + m.rankingModel = m.localCache.RankingModel + m.rankingModelName = m.localCache.RankingModelName + m.rankingModelVersion = m.localCache.RankingModelVersion + m.rankingScore = m.localCache.RankingScore } // create cluster meta cache @@ -126,16 +121,16 @@ func (m *Master) Serve() { } // connect data database - m.DataStore, err = data.Open(m.GorseConfig.Database.DataStore) + m.DataClient, err = data.Open(m.GorseConfig.Database.DataStore) if err != nil { base.Logger().Fatal("failed to connect data database", zap.Error(err)) } - if err = m.DataStore.Init(); err != nil { + if err = m.DataClient.Init(); err != nil { base.Logger().Fatal("failed to init database", zap.Error(err)) } // connect cache database - m.CacheStore, err = cache.Open(m.GorseConfig.Database.CacheStore) + m.CacheClient, err = cache.Open(m.GorseConfig.Database.CacheStore) if err != nil { base.Logger().Fatal("failed to connect cache database", zap.Error(err), zap.String("database", m.GorseConfig.Database.CacheStore)) @@ -172,20 +167,20 @@ func (m *Master) FitLoop() { for { // download dataset base.Logger().Info("load dataset for model fit", zap.Strings("feedback_types", m.GorseConfig.Database.PositiveFeedbackType)) - dataSet, items, feedbacks, err := ranking.LoadDataFromDatabase(m.DataStore, m.GorseConfig.Database.PositiveFeedbackType, + dataSet, items, feedbacks, err := ranking.LoadDataFromDatabase(m.DataClient, m.GorseConfig.Database.PositiveFeedbackType, m.GorseConfig.Database.ItemTTL, m.GorseConfig.Database.PositiveFeedbackTTL) if err != nil { base.Logger().Error("failed to load database", zap.Error(err)) goto sleep } // save stats - if err = m.CacheStore.SetString(cache.GlobalMeta, cache.NumUsers, strconv.Itoa(dataSet.UserCount())); err != nil { + if err = m.CacheClient.SetString(cache.GlobalMeta, cache.NumUsers, strconv.Itoa(dataSet.UserCount())); err != nil { base.Logger().Error("failed to write meta", zap.Error(err)) } - if err = m.CacheStore.SetString(cache.GlobalMeta, cache.NumItems, strconv.Itoa(dataSet.ItemCount())); err != nil { + if err = m.CacheClient.SetString(cache.GlobalMeta, cache.NumItems, strconv.Itoa(dataSet.ItemCount())); err != nil { base.Logger().Error("failed to write meta", zap.Error(err)) } - if err = m.CacheStore.SetString(cache.GlobalMeta, cache.NumPositiveFeedback, strconv.Itoa(dataSet.Count())); err != nil { + if err = m.CacheClient.SetString(cache.GlobalMeta, cache.NumPositiveFeedback, strconv.Itoa(dataSet.Count())); err != nil { base.Logger().Error("failed to write meta", zap.Error(err)) } // sleep if empty @@ -194,25 +189,25 @@ func (m *Master) FitLoop() { goto sleep } // check best model - bestName, bestModel, bestScore = m.prSearcher.GetBestModel() - m.prMutex.Lock() + bestName, bestModel, bestScore = m.rankingModelSearcher.GetBestModel() + m.rankingModelMutex.Lock() if bestName != "" && - (bestName != m.prModelName || bestModel.GetParams().ToString() != m.prModel.GetParams().ToString()) && - (bestScore.NDCG > m.prScore.NDCG) { + (bestName != m.rankingModelName || bestModel.GetParams().ToString() != m.rankingModel.GetParams().ToString()) && + (bestScore.NDCG > m.rankingScore.NDCG) { // 1. best model must have been found. // 2. best model must be different from current model // 3. best model must perform better than current model - m.prModel = bestModel - m.prModelName = bestName + m.rankingModel = bestModel + m.rankingModelName = bestName base.Logger().Info("find better model", zap.String("name", bestName), - zap.Any("params", m.prModel.GetParams())) + zap.Any("params", m.rankingModel.GetParams())) } else if dataSet.UserCount() == lastNumUsers && dataSet.ItemCount() == lastNumItems && dataSet.Count() == lastNumFeedback { // sleep if nothing changed - m.prMutex.Unlock() + m.rankingModelMutex.Unlock() goto sleep } - m.prMutex.Unlock() + m.rankingModelMutex.Unlock() lastNumUsers, lastNumItems, lastNumFeedback = dataSet.UserCount(), dataSet.ItemCount(), dataSet.Count() // update user index m.userIndexMutex.Lock() @@ -220,7 +215,7 @@ func (m *Master) FitLoop() { m.userIndexVersion++ m.userIndexMutex.Unlock() // fit model - m.fitPRModel(dataSet, m.prModel) + m.fitRankingModel(dataSet, m.rankingModel) // collect similar items m.similar(items, dataSet, model.SimilarityDot) // collect popular items @@ -233,7 +228,7 @@ func (m *Master) FitLoop() { } } -// SearchLoop searches optimal recommendation model in background. It never modifies variables other than prSearcher. +// SearchLoop searches optimal recommendation model in background. It never modifies variables other than rankingModelSearcher. func (m *Master) SearchLoop() { defer base.CheckPanic() lastNumUsers, lastNumItems, lastNumFeedback := 0, 0, 0 @@ -241,7 +236,7 @@ func (m *Master) SearchLoop() { var trainSet, valSet *ranking.DataSet // download dataset base.Logger().Info("load dataset for model search", zap.Strings("feedback_types", m.GorseConfig.Database.PositiveFeedbackType)) - dataSet, _, _, err := ranking.LoadDataFromDatabase(m.DataStore, m.GorseConfig.Database.PositiveFeedbackType, + dataSet, _, _, err := ranking.LoadDataFromDatabase(m.DataClient, m.GorseConfig.Database.PositiveFeedbackType, m.GorseConfig.Database.ItemTTL, m.GorseConfig.Database.PositiveFeedbackTTL) if err != nil { base.Logger().Error("failed to load database", zap.Error(err)) @@ -258,7 +253,7 @@ func (m *Master) SearchLoop() { } // start search trainSet, valSet = dataSet.Split(0, 0) - err = m.prSearcher.Fit(trainSet, valSet) + err = m.rankingModelSearcher.Fit(trainSet, valSet) if err != nil { base.Logger().Error("failed to search model", zap.Error(err)) } diff --git a/master/master_test.go b/master/master_test.go index 2bd04449b..30f7ceb5f 100644 --- a/master/master_test.go +++ b/master/master_test.go @@ -47,9 +47,9 @@ func newMockMaster(t *testing.T) *mockMaster { s.cacheStoreServer, err = miniredis.Run() assert.Nil(t, err) // open database - s.DataStore, err = data.Open("redis://" + s.dataStoreServer.Addr()) + s.DataClient, err = data.Open("redis://" + s.dataStoreServer.Addr()) assert.Nil(t, err) - s.CacheStore, err = cache.Open("redis://" + s.cacheStoreServer.Addr()) + s.CacheClient, err = cache.Open("redis://" + s.cacheStoreServer.Addr()) assert.Nil(t, err) return s } @@ -76,21 +76,21 @@ func TestMaster_CollectLatest(t *testing.T) { } m.latest(items) // check latest items - latest, err := m.CacheStore.GetScores(cache.LatestItems, "", 0, 100) + latest, err := m.CacheClient.GetScores(cache.LatestItems, "", 0, 100) assert.Nil(t, err) assert.Equal(t, []cache.ScoredItem{ {items[9].ItemId, float32(items[9].Timestamp.Unix())}, {items[8].ItemId, float32(items[8].Timestamp.Unix())}, {items[7].ItemId, float32(items[7].Timestamp.Unix())}, }, latest) - latest, err = m.CacheStore.GetScores(cache.LatestItems, "even", 0, 100) + latest, err = m.CacheClient.GetScores(cache.LatestItems, "even", 0, 100) assert.Nil(t, err) assert.Equal(t, []cache.ScoredItem{ {items[8].ItemId, float32(items[8].Timestamp.Unix())}, {items[6].ItemId, float32(items[6].Timestamp.Unix())}, {items[4].ItemId, float32(items[4].Timestamp.Unix())}, }, latest) - latest, err = m.CacheStore.GetScores(cache.LatestItems, "odd", 0, 100) + latest, err = m.CacheClient.GetScores(cache.LatestItems, "odd", 0, 100) assert.Nil(t, err) assert.Equal(t, []cache.ScoredItem{ {items[9].ItemId, float32(items[9].Timestamp.Unix())}, @@ -143,21 +143,21 @@ func TestMaster_CollectPopItem(t *testing.T) { } m.popItem(items, feedbacks) // check popular items - popular, err := m.CacheStore.GetScores(cache.PopularItems, "", 0, 100) + popular, err := m.CacheClient.GetScores(cache.PopularItems, "", 0, 100) assert.Nil(t, err) assert.Equal(t, []cache.ScoredItem{ {ItemId: items[9].ItemId, Score: 10}, {ItemId: items[8].ItemId, Score: 9}, {ItemId: items[7].ItemId, Score: 8}, }, popular) - popular, err = m.CacheStore.GetScores(cache.PopularItems, "even", 0, 100) + popular, err = m.CacheClient.GetScores(cache.PopularItems, "even", 0, 100) assert.Nil(t, err) assert.Equal(t, []cache.ScoredItem{ {ItemId: items[8].ItemId, Score: 9}, {ItemId: items[6].ItemId, Score: 7}, {ItemId: items[4].ItemId, Score: 5}, }, popular) - popular, err = m.CacheStore.GetScores(cache.PopularItems, "odd", 0, 100) + popular, err = m.CacheClient.GetScores(cache.PopularItems, "odd", 0, 100) assert.Nil(t, err) assert.Equal(t, []cache.ScoredItem{ {ItemId: items[9].ItemId, Score: 10}, @@ -201,15 +201,15 @@ func TestMaster_FitCFModel(t *testing.T) { } } var err error - err = m.DataStore.BatchInsertItem(items) + err = m.DataClient.BatchInsertItem(items) assert.Nil(t, err) - err = m.DataStore.BatchInsertFeedback(feedbacks, true, true) + err = m.DataClient.BatchInsertFeedback(feedbacks, true, true) assert.Nil(t, err) - dataset, _, _, err := ranking.LoadDataFromDatabase(m.DataStore, []string{"FeedbackType"}, 0, 0) + dataset, _, _, err := ranking.LoadDataFromDatabase(m.DataClient, []string{"FeedbackType"}, 0, 0) assert.Nil(t, err) // similar items (common users) m.similar(items, dataset, model.SimilarityDot) - similar, err := m.CacheStore.GetScores(cache.SimilarItems, "9", 0, 100) + similar, err := m.CacheClient.GetScores(cache.SimilarItems, "9", 0, 100) assert.Nil(t, err) assert.Equal(t, []string{"8", "7", "6"}, cache.RemoveScores(similar)) } diff --git a/master/recommend.go b/master/recommend.go index fe74f2f28..28b7b3dcf 100644 --- a/master/recommend.go +++ b/master/recommend.go @@ -46,11 +46,11 @@ func (m *Master) popItem(items []data.Item, feedback []data.Feedback) { // write back for label, topItems := range popItems { result, scores := topItems.PopAll() - if err := m.CacheStore.SetScores(cache.PopularItems, label, cache.CreateScoredItems(result, scores)); err != nil { + if err := m.CacheClient.SetScores(cache.PopularItems, label, cache.CreateScoredItems(result, scores)); err != nil { base.Logger().Error("failed to cache popular items", zap.Error(err)) } } - if err := m.CacheStore.SetString(cache.GlobalMeta, cache.CollectPopularTime, base.Now()); err != nil { + if err := m.CacheClient.SetString(cache.GlobalMeta, cache.CollectPopularTime, base.Now()); err != nil { base.Logger().Error("failed to cache popular items", zap.Error(err)) } } @@ -73,11 +73,11 @@ func (m *Master) latest(items []data.Item) { } for label, topItems := range latestItems { result, scores := topItems.PopAll() - if err = m.CacheStore.SetScores(cache.LatestItems, label, cache.CreateScoredItems(result, scores)); err != nil { + if err = m.CacheClient.SetScores(cache.LatestItems, label, cache.CreateScoredItems(result, scores)); err != nil { base.Logger().Error("failed to cache latest items", zap.Error(err)) } } - if err = m.CacheStore.SetString(cache.GlobalMeta, cache.CollectLatestTime, base.Now()); err != nil { + if err = m.CacheClient.SetString(cache.GlobalMeta, cache.CollectLatestTime, base.Now()); err != nil { base.Logger().Error("failed to cache latest items time", zap.Error(err)) } } @@ -135,7 +135,7 @@ func (m *Master) similar(items []data.Item, dataset *ranking.DataSet, similarity for i := range recommends { recommends[i] = dataset.ItemIndex.ToName(elem[i]) } - if err := m.CacheStore.SetScores(cache.SimilarItems, dataset.ItemIndex.ToName(jobId), cache.CreateScoredItems(recommends, scores)); err != nil { + if err := m.CacheClient.SetScores(cache.SimilarItems, dataset.ItemIndex.ToName(jobId), cache.CreateScoredItems(recommends, scores)); err != nil { return err } completed <- nil @@ -144,7 +144,7 @@ func (m *Master) similar(items []data.Item, dataset *ranking.DataSet, similarity base.Logger().Error("failed to cache similar items", zap.Error(err)) } close(completed) - if err := m.CacheStore.SetString(cache.GlobalMeta, cache.CollectSimilarTime, base.Now()); err != nil { + if err := m.CacheClient.SetString(cache.GlobalMeta, cache.CollectSimilarTime, base.Now()); err != nil { base.Logger().Error("failed to cache similar items", zap.Error(err)) } } @@ -181,47 +181,47 @@ func dotInt(a, b []int) float32 { return sum } -func (m *Master) fitPRModel(dataSet *ranking.DataSet, prModel ranking.Model) { +func (m *Master) fitRankingModel(dataSet *ranking.DataSet, prModel ranking.Model) { base.Logger().Info("fit personal ranking model", zap.Int("n_jobs", m.GorseConfig.Master.FitJobs)) // training model trainSet, testSet := dataSet.Split(0, 0) score := prModel.Fit(trainSet, testSet, nil) // update match model - m.prMutex.Lock() - m.prModel = prModel - m.prVersion++ - m.prScore = score - m.prMutex.Unlock() + m.rankingModelMutex.Lock() + m.rankingModel = prModel + m.rankingModelVersion++ + m.rankingScore = score + m.rankingModelMutex.Unlock() base.Logger().Info("fit personal ranking model complete", - zap.String("version", fmt.Sprintf("%x", m.prVersion))) - if err := m.DataStore.InsertMeasurement(data.Measurement{Name: "NDCG@10", Value: score.NDCG, Timestamp: time.Now()}); err != nil { + zap.String("version", fmt.Sprintf("%x", m.rankingModelVersion))) + if err := m.DataClient.InsertMeasurement(data.Measurement{Name: "NDCG@10", Value: score.NDCG, Timestamp: time.Now()}); err != nil { base.Logger().Error("failed to insert measurement", zap.Error(err)) } - if err := m.DataStore.InsertMeasurement(data.Measurement{Name: "Recall@10", Value: score.Recall, Timestamp: time.Now()}); err != nil { + if err := m.DataClient.InsertMeasurement(data.Measurement{Name: "Recall@10", Value: score.Recall, Timestamp: time.Now()}); err != nil { base.Logger().Error("failed to insert measurement", zap.Error(err)) } - if err := m.DataStore.InsertMeasurement(data.Measurement{Name: "Precision@10", Value: score.Precision, Timestamp: time.Now()}); err != nil { + if err := m.DataClient.InsertMeasurement(data.Measurement{Name: "Precision@10", Value: score.Precision, Timestamp: time.Now()}); err != nil { base.Logger().Error("failed to insert measurement", zap.Error(err)) } - if err := m.CacheStore.SetString(cache.GlobalMeta, cache.FitMatrixFactorizationTime, base.Now()); err != nil { + if err := m.CacheClient.SetString(cache.GlobalMeta, cache.FitMatrixFactorizationTime, base.Now()); err != nil { base.Logger().Error("failed to write meta", zap.Error(err)) } - if err := m.CacheStore.SetString(cache.GlobalMeta, cache.MatrixFactorizationVersion, fmt.Sprintf("%x", m.prVersion)); err != nil { + if err := m.CacheClient.SetString(cache.GlobalMeta, cache.MatrixFactorizationVersion, fmt.Sprintf("%x", m.rankingModelVersion)); err != nil { base.Logger().Error("failed to write meta", zap.Error(err)) } // caching model - m.localCache.ModelName = m.prModelName - m.localCache.ModelVersion = m.prVersion - m.localCache.Model = prModel - m.localCache.ModelScore = score + m.localCache.RankingModelName = m.rankingModelName + m.localCache.RankingModelVersion = m.rankingModelVersion + m.localCache.RankingModel = prModel + m.localCache.RankingScore = score m.localCache.UserIndex = m.userIndex if err := m.localCache.WriteLocalCache(); err != nil { base.Logger().Error("failed to write local cache", zap.Error(err)) } else { base.Logger().Info("write model to local cache", - zap.String("model_name", m.localCache.ModelName), - zap.String("model_version", base.Hex(m.localCache.ModelVersion)), - zap.Float32("model_score", m.localCache.ModelScore.NDCG), - zap.Any("params", m.localCache.Model.GetParams())) + zap.String("model_name", m.localCache.RankingModelName), + zap.String("model_version", base.Hex(m.localCache.RankingModelVersion)), + zap.Float32("model_score", m.localCache.RankingScore.NDCG), + zap.Any("params", m.localCache.RankingModel.GetParams())) } } diff --git a/master/rest.go b/master/rest.go index c0a3ad164..85539d3a3 100644 --- a/master/rest.go +++ b/master/rest.go @@ -159,32 +159,32 @@ type Status struct { NumUsers string NumItems string NumPosFeedback string - PRModel string - CTRModel string + RankingModel string + ClickModel string } func (m *Master) getStats(request *restful.Request, response *restful.Response) { status := Status{} var err error // read number of users - status.NumUsers, err = m.CacheStore.GetString(cache.GlobalMeta, cache.NumUsers) + status.NumUsers, err = m.CacheClient.GetString(cache.GlobalMeta, cache.NumUsers) if err != nil && err != cache.ErrObjectNotExist { server.InternalServerError(response, err) return } // read number of items - status.NumItems, err = m.CacheStore.GetString(cache.GlobalMeta, cache.NumItems) + status.NumItems, err = m.CacheClient.GetString(cache.GlobalMeta, cache.NumItems) if err != nil && err != cache.ErrObjectNotExist { server.InternalServerError(response, err) return } // read number of positive feedback - status.NumPosFeedback, err = m.CacheStore.GetString(cache.GlobalMeta, cache.NumPositiveFeedback) + status.NumPosFeedback, err = m.CacheClient.GetString(cache.GlobalMeta, cache.NumPositiveFeedback) if err != nil && err != cache.ErrObjectNotExist { server.InternalServerError(response, err) return } - status.PRModel = m.prModelName + status.RankingModel = m.rankingModelName server.Ok(response, status) } @@ -203,7 +203,7 @@ func (m *Master) getUser(request *restful.Request, response *restful.Response) { // get user id userId := request.PathParameter("user-id") // get user - user, err := m.DataStore.GetUser(userId) + user, err := m.DataClient.GetUser(userId) if err != nil { if err == data.ErrUserNotExist { server.PageNotFound(response, err) @@ -213,11 +213,11 @@ func (m *Master) getUser(request *restful.Request, response *restful.Response) { return } detail := User{User: user} - if detail.LastActiveTime, err = m.CacheStore.GetString(cache.LastActiveTime, user.UserId); err != nil && err != cache.ErrObjectNotExist { + if detail.LastActiveTime, err = m.CacheClient.GetString(cache.LastActiveTime, user.UserId); err != nil && err != cache.ErrObjectNotExist { server.InternalServerError(response, err) return } - if detail.LastUpdateTime, err = m.CacheStore.GetString(cache.LastUpdateRecommendTime, user.UserId); err != nil && err != cache.ErrObjectNotExist { + if detail.LastUpdateTime, err = m.CacheClient.GetString(cache.LastUpdateRecommendTime, user.UserId); err != nil && err != cache.ErrObjectNotExist { server.InternalServerError(response, err) return } @@ -233,7 +233,7 @@ func (m *Master) getUsers(request *restful.Request, response *restful.Response) return } // get all users - cursor, users, err := m.DataStore.GetUsers(cursor, n) + cursor, users, err := m.DataClient.GetUsers(cursor, n) if err != nil { server.InternalServerError(response, err) return @@ -241,11 +241,11 @@ func (m *Master) getUsers(request *restful.Request, response *restful.Response) details := make([]User, len(users)) for i, user := range users { details[i].User = user - if details[i].LastActiveTime, err = m.CacheStore.GetString(cache.LastActiveTime, user.UserId); err != nil && err != cache.ErrObjectNotExist { + if details[i].LastActiveTime, err = m.CacheClient.GetString(cache.LastActiveTime, user.UserId); err != nil && err != cache.ErrObjectNotExist { server.InternalServerError(response, err) return } - if details[i].LastUpdateTime, err = m.CacheStore.GetString(cache.LastUpdateRecommendTime, user.UserId); err != nil && err != cache.ErrObjectNotExist { + if details[i].LastUpdateTime, err = m.CacheClient.GetString(cache.LastUpdateRecommendTime, user.UserId); err != nil && err != cache.ErrObjectNotExist { server.InternalServerError(response, err) return } @@ -269,7 +269,7 @@ func (m *Master) getRecommend(request *restful.Request, response *restful.Respon // Send result details := make([]data.Item, len(results)) for i := range results { - details[i], err = m.DataStore.GetItem(results[i]) + details[i], err = m.DataClient.GetItem(results[i]) if err != nil { server.InternalServerError(response, err) return @@ -290,7 +290,7 @@ type Feedback struct { func (m *Master) getTypedFeedbackByUser(request *restful.Request, response *restful.Response) { feedbackType := request.PathParameter("feedback-type") userId := request.PathParameter("user-id") - feedback, err := m.DataStore.GetUserFeedback(userId, feedbackType) + feedback, err := m.DataClient.GetUserFeedback(userId, feedbackType) if err != nil { server.InternalServerError(response, err) return @@ -301,7 +301,7 @@ func (m *Master) getTypedFeedbackByUser(request *restful.Request, response *rest details[i].UserId = feedback[i].UserId details[i].Timestamp = feedback[i].Timestamp details[i].Comment = feedback[i].Comment - details[i].Item, err = m.DataStore.GetItem(feedback[i].ItemId) + details[i].Item, err = m.DataClient.GetItem(feedback[i].ItemId) if err != nil { server.InternalServerError(response, err) return @@ -322,7 +322,7 @@ func (m *Master) getList(prefix string, name string, request *restful.Request, r return } // Get the popular list - items, err := m.CacheStore.GetScores(prefix, name, begin, end) + items, err := m.CacheClient.GetScores(prefix, name, begin, end) if err != nil { server.InternalServerError(response, err) return @@ -330,7 +330,7 @@ func (m *Master) getList(prefix string, name string, request *restful.Request, r // Send result details := make([]data.Item, len(items)) for i := range items { - details[i], err = m.DataStore.GetItem(items[i].ItemId) + details[i], err = m.DataClient.GetItem(items[i].ItemId) if err != nil { server.InternalServerError(response, err) return @@ -369,7 +369,7 @@ func (m *Master) importExportItems(response http.ResponseWriter, request *http.R const batchSize = 1024 for { var items []data.Item - cursor, items, err = m.DataStore.GetItems(cursor, batchSize, nil) + cursor, items, err = m.DataClient.GetItems(cursor, batchSize, nil) if err != nil { server.InternalServerError(restful.NewResponse(response), err) return @@ -451,7 +451,7 @@ func (m *Master) importItems(response http.ResponseWriter, file io.Reader, hasHe // 4. comment item.Comment = splits[3] items = append(items, item) - //err = m.DataStore.InsertItem(item) + //err = m.DataClient.InsertItem(item) //if err != nil { // server.InternalServerError(restful.NewResponse(response), err) // return false @@ -463,7 +463,7 @@ func (m *Master) importItems(response http.ResponseWriter, file io.Reader, hasHe server.BadRequest(restful.NewResponse(response), err) return } - err = m.DataStore.BatchInsertItem(items) + err = m.DataClient.BatchInsertItem(items) if err != nil { server.InternalServerError(restful.NewResponse(response), err) return @@ -520,7 +520,7 @@ func (m *Master) importExportFeedback(response http.ResponseWriter, request *htt const batchSize = 1024 for { var feedback []data.Feedback - cursor, feedback, err = m.DataStore.GetFeedback(cursor, batchSize, nil) + cursor, feedback, err = m.DataClient.GetFeedback(cursor, batchSize, nil) if err != nil { server.InternalServerError(restful.NewResponse(response), err) return @@ -609,13 +609,20 @@ func (m *Master) importFeedback(response http.ResponseWriter, file io.Reader, ha server.BadRequest(restful.NewResponse(response), err) return } - err = m.DataStore.BatchInsertFeedback(feedbacks, + // insert to data store + err = m.DataClient.BatchInsertFeedback(feedbacks, m.GorseConfig.Database.AutoInsertUser, m.GorseConfig.Database.AutoInsertItem) if err != nil { server.InternalServerError(restful.NewResponse(response), err) return } + // insert to cache store + err = m.InsertFeedbackToCache(feedbacks) + if err != nil { + server.InternalServerError(restful.NewResponse(response), err) + return + } timeUsed := time.Since(timeStart) base.Logger().Info("complete import feedback", zap.Duration("time_used", timeUsed), diff --git a/master/rest_test.go b/master/rest_test.go index 453dd9602..f2e700c91 100644 --- a/master/rest_test.go +++ b/master/rest_test.go @@ -35,10 +35,8 @@ import ( type mockServer struct { dataStoreServer *miniredis.Miniredis cacheStoreServer *miniredis.Miniredis - dataStoreClient data.Database - cacheStoreClient cache.Database handler *restful.Container - master *Master + Master } func newMockServer(t *testing.T) *mockServer { @@ -50,29 +48,27 @@ func newMockServer(t *testing.T) *mockServer { s.cacheStoreServer, err = miniredis.Run() assert.Nil(t, err) // open database - s.dataStoreClient, err = data.Open("redis://" + s.dataStoreServer.Addr()) + s.DataClient, err = data.Open("redis://" + s.dataStoreServer.Addr()) assert.Nil(t, err) - s.cacheStoreClient, err = cache.Open("redis://" + s.cacheStoreServer.Addr()) + s.CacheClient, err = cache.Open("redis://" + s.cacheStoreServer.Addr()) assert.Nil(t, err) // create server - s.master = &Master{ - RestServer: server.RestServer{ - DataStore: s.dataStoreClient, - CacheStore: s.cacheStoreClient, - GorseConfig: (*config.Config)(nil).LoadDefaultIfNil(), - }, - } - s.master.WebService = new(restful.WebService) - s.master.CreateWebService() + s.GorseConfig = (*config.Config)(nil).LoadDefaultIfNil() + s.WebService = new(restful.WebService) + s.CreateWebService() // create handler s.handler = restful.NewContainer() - s.handler.Add(s.master.WebService) + s.handler.Add(s.WebService) return s } func (s *mockServer) Close(t *testing.T) { - err := s.dataStoreClient.Close() + err := s.DataClient.Close() + assert.Nil(t, err) + err = s.CacheClient.Close() assert.Nil(t, err) + s.dataStoreServer.Close() + s.cacheStoreServer.Close() } func marshal(t *testing.T, v interface{}) string { @@ -90,12 +86,12 @@ func TestMaster_ExportItems(t *testing.T) { {"2", time.Date(2021, 1, 1, 1, 1, 1, 1, time.UTC), []string{"b", "c"}, "t\r\nw\r\no"}, {"3", time.Date(2022, 1, 1, 1, 1, 1, 1, time.UTC), []string{"c", "d"}, "\"three\""}, } - err := s.dataStoreClient.BatchInsertItem(items) + err := s.DataClient.BatchInsertItem(items) assert.Nil(t, err) // send request req := httptest.NewRequest("GET", "https://example.com/", nil) w := httptest.NewRecorder() - s.master.importExportItems(w, req) + s.importExportItems(w, req) assert.Equal(t, http.StatusOK, w.Result().StatusCode) assert.Equal(t, "text/csv", w.Header().Get("Content-Type")) assert.Equal(t, "attachment;filename=items.csv", w.Header().Get("Content-Disposition")) @@ -114,12 +110,12 @@ func TestMaster_ExportFeedback(t *testing.T) { {FeedbackKey: data.FeedbackKey{FeedbackType: "share", UserId: "1", ItemId: "4"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "read", UserId: "2", ItemId: "6"}}, } - err := s.dataStoreClient.BatchInsertFeedback(feedbacks, true, true) + err := s.DataClient.BatchInsertFeedback(feedbacks, true, true) assert.Nil(t, err) // send request req := httptest.NewRequest("GET", "https://example.com/", nil) w := httptest.NewRecorder() - s.master.importExportFeedback(w, req) + s.importExportFeedback(w, req) assert.Equal(t, http.StatusOK, w.Result().StatusCode) assert.Equal(t, "text/csv", w.Header().Get("Content-Type")) assert.Equal(t, "attachment;filename=feedback.csv", w.Header().Get("Content-Disposition")) @@ -154,11 +150,11 @@ func TestMaster_ImportItems(t *testing.T) { req := httptest.NewRequest("POST", "https://example.com/", buf) req.Header.Set("Content-Type", writer.FormDataContentType()) w := httptest.NewRecorder() - s.master.importExportItems(w, req) + s.importExportItems(w, req) // check assert.Equal(t, http.StatusOK, w.Result().StatusCode) assert.JSONEq(t, marshal(t, server.Success{RowAffected: 3}), w.Body.String()) - _, items, err := s.dataStoreClient.GetItems("", 100, nil) + _, items, err := s.DataClient.GetItems("", 100, nil) assert.Nil(t, err) assert.Equal(t, []data.Item{ {"1", time.Date(2020, 1, 1, 1, 1, 1, 1, time.UTC), []string{"a", "b"}, "o,n,e"}, @@ -185,11 +181,11 @@ func TestMaster_ImportItems_DefaultFormat(t *testing.T) { req := httptest.NewRequest("POST", "https://example.com/", buf) req.Header.Set("Content-Type", writer.FormDataContentType()) w := httptest.NewRecorder() - s.master.importExportItems(w, req) + s.importExportItems(w, req) // check assert.Equal(t, http.StatusOK, w.Result().StatusCode) assert.JSONEq(t, marshal(t, server.Success{RowAffected: 3}), w.Body.String()) - _, items, err := s.dataStoreClient.GetItems("", 100, nil) + _, items, err := s.DataClient.GetItems("", 100, nil) assert.Nil(t, err) assert.Equal(t, []data.Item{ {"1", time.Date(2020, 1, 1, 1, 1, 1, 1, time.UTC), []string{"a", "b"}, "one"}, @@ -221,11 +217,11 @@ func TestMaster_ImportFeedback(t *testing.T) { req := httptest.NewRequest("POST", "https://example.com/", buf) req.Header.Set("Content-Type", writer.FormDataContentType()) w := httptest.NewRecorder() - s.master.importExportFeedback(w, req) + s.importExportFeedback(w, req) // check assert.Equal(t, http.StatusOK, w.Result().StatusCode) assert.JSONEq(t, marshal(t, server.Success{RowAffected: 3}), w.Body.String()) - _, feedback, err := s.dataStoreClient.GetFeedback("", 100, nil) + _, feedback, err := s.DataClient.GetFeedback("", 100, nil) assert.Nil(t, err) assert.Equal(t, []data.Feedback{ {FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "2"}}, @@ -252,11 +248,11 @@ func TestMaster_ImportFeedback_Default(t *testing.T) { req := httptest.NewRequest("POST", "https://example.com/", buf) req.Header.Set("Content-Type", writer.FormDataContentType()) w := httptest.NewRecorder() - s.master.importExportFeedback(w, req) + s.importExportFeedback(w, req) // check assert.Equal(t, http.StatusOK, w.Result().StatusCode) assert.JSONEq(t, marshal(t, server.Success{RowAffected: 3}), w.Body.String()) - _, feedback, err := s.dataStoreClient.GetFeedback("", 100, nil) + _, feedback, err := s.DataClient.GetFeedback("", 100, nil) assert.Nil(t, err) assert.Equal(t, []data.Feedback{ {FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "2"}}, @@ -271,9 +267,9 @@ func TestMaster_GetCluster(t *testing.T) { // add nodes serverNode := &Node{"alan turnin", ServerNode, "192.168.1.100", 1080} workerNode := &Node{"dennis ritchie", WorkerNode, "192.168.1.101", 1081} - s.master.nodesInfo = make(map[string]*Node) - s.master.nodesInfo["alan turning"] = serverNode - s.master.nodesInfo["dennis ritchie"] = workerNode + s.nodesInfo = make(map[string]*Node) + s.nodesInfo["alan turning"] = serverNode + s.nodesInfo["dennis ritchie"] = workerNode // get nodes apitest.New(). Handler(s.handler). @@ -288,12 +284,12 @@ func TestMaster_GetStats(t *testing.T) { s := newMockServer(t) defer s.Close(t) // set stats - s.master.prModelName = "ccd" - err := s.cacheStoreClient.SetString(cache.GlobalMeta, cache.NumItems, "123") + s.rankingModelName = "ccd" + err := s.CacheClient.SetString(cache.GlobalMeta, cache.NumItems, "123") assert.Nil(t, err) - err = s.cacheStoreClient.SetString(cache.GlobalMeta, cache.NumUsers, "234") + err = s.CacheClient.SetString(cache.GlobalMeta, cache.NumUsers, "234") assert.Nil(t, err) - err = s.cacheStoreClient.SetString(cache.GlobalMeta, cache.NumPositiveFeedback, "345") + err = s.CacheClient.SetString(cache.GlobalMeta, cache.NumPositiveFeedback, "345") assert.Nil(t, err) // get stats apitest.New(). @@ -305,7 +301,7 @@ func TestMaster_GetStats(t *testing.T) { NumUsers: "234", NumItems: "123", NumPosFeedback: "345", - PRModel: "ccd", + RankingModel: "ccd", })). End() } @@ -320,11 +316,11 @@ func TestMaster_GetUsers(t *testing.T) { {data.User{UserId: "2"}, "2002-01-01", "2022-01-02"}, } for _, user := range users { - err := s.dataStoreClient.InsertUser(user.User) + err := s.DataClient.InsertUser(user.User) assert.Nil(t, err) - err = s.cacheStoreClient.SetString(cache.LastActiveTime, user.UserId, user.LastActiveTime) + err = s.CacheClient.SetString(cache.LastActiveTime, user.UserId, user.LastActiveTime) assert.Nil(t, err) - err = s.cacheStoreClient.SetString(cache.LastUpdateRecommendTime, user.UserId, user.LastUpdateTime) + err = s.CacheClient.SetString(cache.LastUpdateRecommendTime, user.UserId, user.LastUpdateTime) assert.Nil(t, err) } // get users @@ -371,12 +367,12 @@ func TestServer_List(t *testing.T) { {"3", 97}, {"4", 96}, } - err := s.cacheStoreClient.SetScores(operator.Prefix, operator.Label, itemIds) + err := s.CacheClient.SetScores(operator.Prefix, operator.Label, itemIds) assert.Nil(t, err) items := make([]data.Item, 0) for _, item := range itemIds { items = append(items, data.Item{ItemId: item.ItemId}) - err = s.dataStoreClient.InsertItem(data.Item{ItemId: item.ItemId}) + err = s.DataClient.InsertItem(data.Item{ItemId: item.ItemId}) assert.Nil(t, err) } apitest.New(). @@ -401,7 +397,7 @@ func TestServer_Feedback(t *testing.T) { {FeedbackType: "click", UserId: "0", Item: data.Item{ItemId: "8"}}, } for _, v := range feedback { - err := s.dataStoreClient.InsertFeedback(data.Feedback{ + err := s.DataClient.InsertFeedback(data.Feedback{ FeedbackKey: data.FeedbackKey{FeedbackType: v.FeedbackType, UserId: v.UserId, ItemId: v.Item.ItemId}, }, true, true) assert.Nil(t, err) @@ -430,20 +426,18 @@ func TestServer_GetRecommends(t *testing.T) { {"7", 93}, {"8", 92}, } - err := s.cacheStoreClient.SetScores(cache.CollaborativeItems, "0", itemIds) + err := s.CacheClient.SetScores(cache.CollaborativeItems, "0", itemIds) assert.Nil(t, err) // insert feedback feedback := []data.Feedback{ {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "2"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "4"}}, } - for _, v := range feedback { - err = s.master.InsertFeedbackTwice(v, true, true) - assert.Nil(t, err) - } + err = s.RestServer.InsertFeedbackToCache(feedback) + assert.Nil(t, err) // insert items for _, item := range itemIds { - err = s.dataStoreClient.InsertItem(data.Item{ItemId: item.ItemId}) + err = s.DataClient.InsertItem(data.Item{ItemId: item.ItemId}) assert.Nil(t, err) } apitest.New(). diff --git a/master/rpc.go b/master/rpc.go index 5f1e71662..c899d5049 100644 --- a/master/rpc.go +++ b/master/rpc.go @@ -1,3 +1,17 @@ +// Copyright 2021 gorse Project Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + package master import ( @@ -14,6 +28,7 @@ import ( "strings" ) +// Node could be worker node for server node. type Node struct { Name string Type string @@ -21,6 +36,7 @@ type Node struct { HttpPort int64 } +// NewNode creates a node from Context and NodeInfo. func NewNode(ctx context.Context, nodeInfo *protocol.NodeInfo) *Node { node := new(Node) node.Name = nodeInfo.NodeName @@ -39,8 +55,9 @@ func NewNode(ctx context.Context, nodeInfo *protocol.NodeInfo) *Node { return node } +// GetMeta returns latest configuration. func (m *Master) GetMeta(ctx context.Context, nodeInfo *protocol.NodeInfo) (*protocol.Meta, error) { - // save node + // register node node := NewNode(ctx, nodeInfo) if node.Type != "" { if err := m.ttlCache.Set(nodeInfo.NodeName, node); err != nil { @@ -53,27 +70,20 @@ func (m *Master) GetMeta(ctx context.Context, nodeInfo *protocol.NodeInfo) (*pro if err != nil { return nil, err } - // save user index version + // collect user index version m.userIndexMutex.Lock() var userIndexVersion int64 if m.userIndex != nil { userIndexVersion = m.userIndexVersion } m.userIndexMutex.Unlock() - // save pr version - m.prMutex.Lock() - var prVersion int64 - if m.prModel != nil { - prVersion = m.prVersion + // save ranking model version + m.rankingModelMutex.Lock() + var rankingModelVersion int64 + if m.rankingModel != nil { + rankingModelVersion = m.rankingModelVersion } - m.prMutex.Unlock() - // save fm version - //m.fmMutex.Lock() - //var fmVersion int64 - //if m.fmModel != nil { - // fmVersion = m.ctrVersion - //} - //m.fmMutex.Unlock() + m.rankingModelMutex.Unlock() // collect nodes workers := make([]string, 0) servers := make([]string, 0) @@ -88,53 +98,36 @@ func (m *Master) GetMeta(ctx context.Context, nodeInfo *protocol.NodeInfo) (*pro } m.nodesInfoMutex.Unlock() return &protocol.Meta{ - Config: string(s), - UserIndexVersion: userIndexVersion, - //FmVersion: fmVersion, - PrVersion: prVersion, - Me: nodeInfo.NodeName, - Workers: workers, - Servers: servers, + Config: string(s), + UserIndexVersion: userIndexVersion, + RankingModelVersion: rankingModelVersion, + Me: nodeInfo.NodeName, + Workers: workers, + Servers: servers, }, nil } -func (m *Master) GetPRModel(context.Context, *protocol.NodeInfo) (*protocol.Model, error) { - m.prMutex.Lock() - defer m.prMutex.Unlock() +// GetRankingModel returns latest ranking model. +func (m *Master) GetRankingModel(context.Context, *protocol.NodeInfo) (*protocol.Model, error) { + m.rankingModelMutex.Lock() + defer m.rankingModelMutex.Unlock() // skip empty model - if m.prModel == nil { + if m.rankingModel == nil { return &protocol.Model{Version: 0}, nil } // encode model - modelData, err := ranking.EncodeModel(m.prModel) + modelData, err := ranking.EncodeModel(m.rankingModel) if err != nil { return nil, err } return &protocol.Model{ - Name: m.prModelName, - Version: m.prVersion, + Name: m.rankingModelName, + Version: m.rankingModelVersion, Model: modelData, }, nil } -//func (m *Master) GetFactorizationMachine(context.Context, *protocol.NodeInfo) (*protocol.Model, error) { -// m.fmMutex.Lock() -// defer m.fmMutex.Unlock() -// // skip empty model -// if m.fmModel == nil { -// return &protocol.Model{Version: 0}, nil -// } -// // encode model -// modelData, err := ctr.EncodeModel(m.fmModel) -// if err != nil { -// return nil, err -// } -// return &protocol.Model{ -// Version: m.ctrVersion, -// Model: modelData, -// }, nil -//} - +// GetUserIndex returns latest user index. func (m *Master) GetUserIndex(context.Context, *protocol.NodeInfo) (*protocol.UserIndex, error) { m.userIndexMutex.Lock() defer m.userIndexMutex.Unlock() @@ -158,6 +151,7 @@ func (m *Master) GetUserIndex(context.Context, *protocol.NodeInfo) (*protocol.Us }, nil } +// nodeUp handles node information inserted events. func (m *Master) nodeUp(key string, value interface{}) { node := value.(*Node) base.Logger().Info("node up", @@ -169,6 +163,7 @@ func (m *Master) nodeUp(key string, value interface{}) { m.nodesInfo[key] = node } +// nodeDown handles node information timout events. func (m *Master) nodeDown(key string, value interface{}) { node := value.(*Node) base.Logger().Info("node down", diff --git a/protocol/protocol.pb.go b/protocol/protocol.pb.go index 0c591aff2..249db9f83 100644 --- a/protocol/protocol.pb.go +++ b/protocol/protocol.pb.go @@ -93,13 +93,12 @@ type Meta struct { sizeCache protoimpl.SizeCache unknownFields protoimpl.UnknownFields - Config string `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"` - UserIndexVersion int64 `protobuf:"varint,2,opt,name=user_index_version,json=userIndexVersion,proto3" json:"user_index_version,omitempty"` - PrVersion int64 `protobuf:"varint,3,opt,name=pr_version,json=prVersion,proto3" json:"pr_version,omitempty"` - CtrVersion int64 `protobuf:"varint,4,opt,name=ctr_version,json=ctrVersion,proto3" json:"ctr_version,omitempty"` - Me string `protobuf:"bytes,5,opt,name=me,proto3" json:"me,omitempty"` - Servers []string `protobuf:"bytes,6,rep,name=servers,proto3" json:"servers,omitempty"` - Workers []string `protobuf:"bytes,7,rep,name=workers,proto3" json:"workers,omitempty"` + Config string `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"` + UserIndexVersion int64 `protobuf:"varint,2,opt,name=user_index_version,json=userIndexVersion,proto3" json:"user_index_version,omitempty"` + RankingModelVersion int64 `protobuf:"varint,3,opt,name=ranking_model_version,json=rankingModelVersion,proto3" json:"ranking_model_version,omitempty"` + Me string `protobuf:"bytes,5,opt,name=me,proto3" json:"me,omitempty"` + Servers []string `protobuf:"bytes,6,rep,name=servers,proto3" json:"servers,omitempty"` + Workers []string `protobuf:"bytes,7,rep,name=workers,proto3" json:"workers,omitempty"` } func (x *Meta) Reset() { @@ -148,16 +147,9 @@ func (x *Meta) GetUserIndexVersion() int64 { return 0 } -func (x *Meta) GetPrVersion() int64 { +func (x *Meta) GetRankingModelVersion() int64 { if x != nil { - return x.PrVersion - } - return 0 -} - -func (x *Meta) GetCtrVersion() int64 { - if x != nil { - return x.CtrVersion + return x.RankingModelVersion } return 0 } @@ -368,58 +360,54 @@ var File_protocol_proto protoreflect.FileDescriptor var file_protocol_proto_rawDesc = []byte{ 0x0a, 0x0e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x12, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0xd0, 0x01, 0x0a, 0x04, 0x4d, + 0x12, 0x08, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x22, 0xc4, 0x01, 0x0a, 0x04, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x16, 0x0a, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x63, 0x6f, 0x6e, 0x66, 0x69, 0x67, 0x12, 0x2c, 0x0a, 0x12, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x02, 0x20, 0x01, 0x28, 0x03, 0x52, 0x10, 0x75, 0x73, 0x65, 0x72, 0x49, 0x6e, 0x64, - 0x65, 0x78, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1d, 0x0a, 0x0a, 0x70, 0x72, 0x5f, - 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x70, - 0x72, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1f, 0x0a, 0x0b, 0x63, 0x74, 0x72, 0x5f, - 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x0a, 0x63, - 0x74, 0x72, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, 0x02, 0x6d, 0x65, 0x18, - 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x6d, 0x65, 0x12, 0x18, 0x0a, 0x07, 0x73, 0x65, 0x72, - 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x73, 0x65, 0x72, 0x76, - 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x77, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x73, 0x18, 0x07, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x77, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x73, 0x22, 0x44, 0x0a, - 0x09, 0x55, 0x73, 0x65, 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, - 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x76, 0x65, 0x72, - 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1d, 0x0a, 0x0a, 0x75, 0x73, 0x65, 0x72, 0x5f, 0x69, 0x6e, 0x64, - 0x65, 0x78, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x75, 0x73, 0x65, 0x72, 0x49, 0x6e, - 0x64, 0x65, 0x78, 0x22, 0x4b, 0x0a, 0x05, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x18, 0x0a, 0x07, - 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x07, 0x76, - 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x14, 0x0a, 0x05, 0x6d, 0x6f, - 0x64, 0x65, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x6d, 0x6f, 0x64, 0x65, 0x6c, - 0x22, 0x75, 0x0a, 0x08, 0x4e, 0x6f, 0x64, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x12, 0x2f, 0x0a, 0x09, - 0x6e, 0x6f, 0x64, 0x65, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0e, 0x32, - 0x12, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x4e, 0x6f, 0x64, 0x65, 0x54, - 0x79, 0x70, 0x65, 0x52, 0x08, 0x6e, 0x6f, 0x64, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x1b, 0x0a, - 0x09, 0x6e, 0x6f, 0x64, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x08, 0x6e, 0x6f, 0x64, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x68, 0x74, - 0x74, 0x70, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x08, 0x68, - 0x74, 0x74, 0x70, 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x3a, 0x0a, 0x08, 0x4e, 0x6f, 0x64, 0x65, 0x54, - 0x79, 0x70, 0x65, 0x12, 0x0e, 0x0a, 0x0a, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x4e, 0x6f, 0x64, - 0x65, 0x10, 0x00, 0x12, 0x0e, 0x0a, 0x0a, 0x57, 0x6f, 0x72, 0x6b, 0x65, 0x72, 0x4e, 0x6f, 0x64, - 0x65, 0x10, 0x01, 0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x4e, 0x6f, 0x64, - 0x65, 0x10, 0x02, 0x32, 0xdf, 0x01, 0x0a, 0x06, 0x4d, 0x61, 0x73, 0x74, 0x65, 0x72, 0x12, 0x2f, - 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x12, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x4e, 0x6f, 0x64, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x1a, 0x0e, 0x2e, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x4d, 0x65, 0x74, 0x61, 0x22, 0x00, 0x12, - 0x39, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x55, 0x73, 0x65, 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x12, - 0x12, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x4e, 0x6f, 0x64, 0x65, 0x49, - 0x6e, 0x66, 0x6f, 0x1a, 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x55, - 0x73, 0x65, 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x22, 0x00, 0x12, 0x34, 0x0a, 0x0b, 0x47, 0x65, - 0x74, 0x43, 0x54, 0x52, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x12, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x4e, 0x6f, 0x64, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x1a, 0x0f, 0x2e, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x22, 0x00, - 0x12, 0x33, 0x0a, 0x0a, 0x47, 0x65, 0x74, 0x50, 0x52, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x12, 0x12, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x4e, 0x6f, 0x64, 0x65, 0x49, 0x6e, - 0x66, 0x6f, 0x1a, 0x0f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x4d, 0x6f, - 0x64, 0x65, 0x6c, 0x22, 0x00, 0x42, 0x25, 0x5a, 0x23, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, - 0x63, 0x6f, 0x6d, 0x2f, 0x7a, 0x68, 0x65, 0x6e, 0x67, 0x68, 0x61, 0x6f, 0x7a, 0x2f, 0x67, 0x6f, - 0x72, 0x73, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x62, 0x06, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x33, + 0x65, 0x78, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x32, 0x0a, 0x15, 0x72, 0x61, 0x6e, + 0x6b, 0x69, 0x6e, 0x67, 0x5f, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, + 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x13, 0x72, 0x61, 0x6e, 0x6b, 0x69, 0x6e, + 0x67, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x0e, 0x0a, + 0x02, 0x6d, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x6d, 0x65, 0x12, 0x18, 0x0a, + 0x07, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x18, 0x06, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, + 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x12, 0x18, 0x0a, 0x07, 0x77, 0x6f, 0x72, 0x6b, 0x65, + 0x72, 0x73, 0x18, 0x07, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x77, 0x6f, 0x72, 0x6b, 0x65, 0x72, + 0x73, 0x22, 0x44, 0x0a, 0x09, 0x55, 0x73, 0x65, 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x12, 0x18, + 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, + 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x1d, 0x0a, 0x0a, 0x75, 0x73, 0x65, 0x72, + 0x5f, 0x69, 0x6e, 0x64, 0x65, 0x78, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x09, 0x75, 0x73, + 0x65, 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x22, 0x4b, 0x0a, 0x05, 0x4d, 0x6f, 0x64, 0x65, 0x6c, + 0x12, 0x18, 0x0a, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, + 0x03, 0x52, 0x07, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, + 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x14, + 0x0a, 0x05, 0x6d, 0x6f, 0x64, 0x65, 0x6c, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x05, 0x6d, + 0x6f, 0x64, 0x65, 0x6c, 0x22, 0x75, 0x0a, 0x08, 0x4e, 0x6f, 0x64, 0x65, 0x49, 0x6e, 0x66, 0x6f, + 0x12, 0x2f, 0x0a, 0x09, 0x6e, 0x6f, 0x64, 0x65, 0x5f, 0x74, 0x79, 0x70, 0x65, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0e, 0x32, 0x12, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x4e, + 0x6f, 0x64, 0x65, 0x54, 0x79, 0x70, 0x65, 0x52, 0x08, 0x6e, 0x6f, 0x64, 0x65, 0x54, 0x79, 0x70, + 0x65, 0x12, 0x1b, 0x0a, 0x09, 0x6e, 0x6f, 0x64, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x6e, 0x6f, 0x64, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x1b, + 0x0a, 0x09, 0x68, 0x74, 0x74, 0x70, 0x5f, 0x70, 0x6f, 0x72, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x03, 0x52, 0x08, 0x68, 0x74, 0x74, 0x70, 0x50, 0x6f, 0x72, 0x74, 0x2a, 0x3a, 0x0a, 0x08, 0x4e, + 0x6f, 0x64, 0x65, 0x54, 0x79, 0x70, 0x65, 0x12, 0x0e, 0x0a, 0x0a, 0x53, 0x65, 0x72, 0x76, 0x65, + 0x72, 0x4e, 0x6f, 0x64, 0x65, 0x10, 0x00, 0x12, 0x0e, 0x0a, 0x0a, 0x57, 0x6f, 0x72, 0x6b, 0x65, + 0x72, 0x4e, 0x6f, 0x64, 0x65, 0x10, 0x01, 0x12, 0x0e, 0x0a, 0x0a, 0x43, 0x6c, 0x69, 0x65, 0x6e, + 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x10, 0x02, 0x32, 0xae, 0x01, 0x0a, 0x06, 0x4d, 0x61, 0x73, 0x74, + 0x65, 0x72, 0x12, 0x2f, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4d, 0x65, 0x74, 0x61, 0x12, 0x12, 0x2e, + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x4e, 0x6f, 0x64, 0x65, 0x49, 0x6e, 0x66, + 0x6f, 0x1a, 0x0e, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x4d, 0x65, 0x74, + 0x61, 0x22, 0x00, 0x12, 0x39, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x55, 0x73, 0x65, 0x72, 0x49, 0x6e, + 0x64, 0x65, 0x78, 0x12, 0x12, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x4e, + 0x6f, 0x64, 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x1a, 0x13, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, + 0x6f, 0x6c, 0x2e, 0x55, 0x73, 0x65, 0x72, 0x49, 0x6e, 0x64, 0x65, 0x78, 0x22, 0x00, 0x12, 0x38, + 0x0a, 0x0f, 0x47, 0x65, 0x74, 0x52, 0x61, 0x6e, 0x6b, 0x69, 0x6e, 0x67, 0x4d, 0x6f, 0x64, 0x65, + 0x6c, 0x12, 0x12, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x2e, 0x4e, 0x6f, 0x64, + 0x65, 0x49, 0x6e, 0x66, 0x6f, 0x1a, 0x0f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, + 0x2e, 0x4d, 0x6f, 0x64, 0x65, 0x6c, 0x22, 0x00, 0x42, 0x25, 0x5a, 0x23, 0x67, 0x69, 0x74, 0x68, + 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x7a, 0x68, 0x65, 0x6e, 0x67, 0x68, 0x61, 0x6f, 0x7a, + 0x2f, 0x67, 0x6f, 0x72, 0x73, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x62, + 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( @@ -447,14 +435,12 @@ var file_protocol_proto_depIdxs = []int32{ 0, // 0: protocol.NodeInfo.node_type:type_name -> protocol.NodeType 4, // 1: protocol.Master.GetMeta:input_type -> protocol.NodeInfo 4, // 2: protocol.Master.GetUserIndex:input_type -> protocol.NodeInfo - 4, // 3: protocol.Master.GetCTRModel:input_type -> protocol.NodeInfo - 4, // 4: protocol.Master.GetPRModel:input_type -> protocol.NodeInfo - 1, // 5: protocol.Master.GetMeta:output_type -> protocol.Meta - 2, // 6: protocol.Master.GetUserIndex:output_type -> protocol.UserIndex - 3, // 7: protocol.Master.GetCTRModel:output_type -> protocol.Model - 3, // 8: protocol.Master.GetPRModel:output_type -> protocol.Model - 5, // [5:9] is the sub-list for method output_type - 1, // [1:5] is the sub-list for method input_type + 4, // 3: protocol.Master.GetRankingModel:input_type -> protocol.NodeInfo + 1, // 4: protocol.Master.GetMeta:output_type -> protocol.Meta + 2, // 5: protocol.Master.GetUserIndex:output_type -> protocol.UserIndex + 3, // 6: protocol.Master.GetRankingModel:output_type -> protocol.Model + 4, // [4:7] is the sub-list for method output_type + 1, // [1:4] is the sub-list for method input_type 1, // [1:1] is the sub-list for extension type_name 1, // [1:1] is the sub-list for extension extendee 0, // [0:1] is the sub-list for field type_name diff --git a/protocol/protocol.proto b/protocol/protocol.proto index debf2e082..19b1aab07 100644 --- a/protocol/protocol.proto +++ b/protocol/protocol.proto @@ -30,16 +30,14 @@ service Master { /* data distribute */ rpc GetUserIndex(NodeInfo) returns(UserIndex) {} - rpc GetCTRModel(NodeInfo) returns (Model) {} - rpc GetPRModel(NodeInfo) returns (Model) {} + rpc GetRankingModel(NodeInfo) returns (Model) {} } message Meta { string config = 1; int64 user_index_version = 2; - int64 pr_version = 3; - int64 ctr_version = 4; + int64 ranking_model_version = 3; string me = 5; repeated string servers = 6; repeated string workers = 7; diff --git a/protocol/protocol_grpc.pb.go b/protocol/protocol_grpc.pb.go index 32f00f80a..c50f74737 100644 --- a/protocol/protocol_grpc.pb.go +++ b/protocol/protocol_grpc.pb.go @@ -21,8 +21,7 @@ type MasterClient interface { GetMeta(ctx context.Context, in *NodeInfo, opts ...grpc.CallOption) (*Meta, error) // data distribute GetUserIndex(ctx context.Context, in *NodeInfo, opts ...grpc.CallOption) (*UserIndex, error) - GetCTRModel(ctx context.Context, in *NodeInfo, opts ...grpc.CallOption) (*Model, error) - GetPRModel(ctx context.Context, in *NodeInfo, opts ...grpc.CallOption) (*Model, error) + GetRankingModel(ctx context.Context, in *NodeInfo, opts ...grpc.CallOption) (*Model, error) } type masterClient struct { @@ -51,18 +50,9 @@ func (c *masterClient) GetUserIndex(ctx context.Context, in *NodeInfo, opts ...g return out, nil } -func (c *masterClient) GetCTRModel(ctx context.Context, in *NodeInfo, opts ...grpc.CallOption) (*Model, error) { +func (c *masterClient) GetRankingModel(ctx context.Context, in *NodeInfo, opts ...grpc.CallOption) (*Model, error) { out := new(Model) - err := c.cc.Invoke(ctx, "/protocol.Master/GetCTRModel", in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *masterClient) GetPRModel(ctx context.Context, in *NodeInfo, opts ...grpc.CallOption) (*Model, error) { - out := new(Model) - err := c.cc.Invoke(ctx, "/protocol.Master/GetPRModel", in, out, opts...) + err := c.cc.Invoke(ctx, "/protocol.Master/GetRankingModel", in, out, opts...) if err != nil { return nil, err } @@ -77,8 +67,7 @@ type MasterServer interface { GetMeta(context.Context, *NodeInfo) (*Meta, error) // data distribute GetUserIndex(context.Context, *NodeInfo) (*UserIndex, error) - GetCTRModel(context.Context, *NodeInfo) (*Model, error) - GetPRModel(context.Context, *NodeInfo) (*Model, error) + GetRankingModel(context.Context, *NodeInfo) (*Model, error) mustEmbedUnimplementedMasterServer() } @@ -92,11 +81,8 @@ func (UnimplementedMasterServer) GetMeta(context.Context, *NodeInfo) (*Meta, err func (UnimplementedMasterServer) GetUserIndex(context.Context, *NodeInfo) (*UserIndex, error) { return nil, status.Errorf(codes.Unimplemented, "method GetUserIndex not implemented") } -func (UnimplementedMasterServer) GetCTRModel(context.Context, *NodeInfo) (*Model, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetCTRModel not implemented") -} -func (UnimplementedMasterServer) GetPRModel(context.Context, *NodeInfo) (*Model, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetPRModel not implemented") +func (UnimplementedMasterServer) GetRankingModel(context.Context, *NodeInfo) (*Model, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetRankingModel not implemented") } func (UnimplementedMasterServer) mustEmbedUnimplementedMasterServer() {} @@ -147,38 +133,20 @@ func _Master_GetUserIndex_Handler(srv interface{}, ctx context.Context, dec func return interceptor(ctx, in, info, handler) } -func _Master_GetCTRModel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(NodeInfo) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(MasterServer).GetCTRModel(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: "/protocol.Master/GetCTRModel", - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(MasterServer).GetCTRModel(ctx, req.(*NodeInfo)) - } - return interceptor(ctx, in, info, handler) -} - -func _Master_GetPRModel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { +func _Master_GetRankingModel_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(NodeInfo) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(MasterServer).GetPRModel(ctx, in) + return srv.(MasterServer).GetRankingModel(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: "/protocol.Master/GetPRModel", + FullMethod: "/protocol.Master/GetRankingModel", } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(MasterServer).GetPRModel(ctx, req.(*NodeInfo)) + return srv.(MasterServer).GetRankingModel(ctx, req.(*NodeInfo)) } return interceptor(ctx, in, info, handler) } @@ -196,12 +164,8 @@ var _Master_serviceDesc = grpc.ServiceDesc{ Handler: _Master_GetUserIndex_Handler, }, { - MethodName: "GetCTRModel", - Handler: _Master_GetCTRModel_Handler, - }, - { - MethodName: "GetPRModel", - Handler: _Master_GetPRModel_Handler, + MethodName: "GetRankingModel", + Handler: _Master_GetRankingModel_Handler, }, }, Streams: []grpc.StreamDesc{}, diff --git a/server/local_cache.go b/server/local_cache.go index d25ace9ac..27249c1b4 100644 --- a/server/local_cache.go +++ b/server/local_cache.go @@ -20,11 +20,13 @@ import ( "path/filepath" ) +// LocalCache is local cache for the server node. type LocalCache struct { path string ServerName string } +// LoadLocalCache loads local cache from a file. func LoadLocalCache(path string) (*LocalCache, error) { state := &LocalCache{path: path} // check if file exists @@ -47,6 +49,7 @@ func LoadLocalCache(path string) (*LocalCache, error) { return state, nil } +// WriteLocalCache writes local cache to a file. func (s *LocalCache) WriteLocalCache() error { // create parent folder if not exists parent := filepath.Dir(s.path) diff --git a/server/rest.go b/server/rest.go index 2a15d691a..3e276a827 100644 --- a/server/rest.go +++ b/server/rest.go @@ -32,9 +32,10 @@ import ( "go.uber.org/zap" ) +// RestServer implements a REST-ful API server. type RestServer struct { - CacheStore cache.Database - DataStore data.Database + CacheClient cache.Database + DataClient data.Database GorseConfig *config.Config HttpHost string HttpPort int @@ -42,6 +43,7 @@ type RestServer struct { WebService *restful.WebService } +// StartHttpServer starts the REST-ful API server. func (s *RestServer) StartHttpServer() { // register restful APIs s.CreateWebService() @@ -63,6 +65,7 @@ func (s *RestServer) StartHttpServer() { zap.Error(http.ListenAndServe(fmt.Sprintf("%s:%d", s.HttpHost, s.HttpPort), nil))) } +// CreateWebService creates web service. func (s *RestServer) CreateWebService() { // Create a server ws := s.WebService @@ -233,16 +236,6 @@ func (s *RestServer) CreateWebService() { Param(ws.QueryParameter("n", "number of returned items").DataType("int")). Param(ws.QueryParameter("offset", "offset of the list").DataType("int")). Writes([]string{})) - // Get subscribe items - //ws.Route(ws.GET("/intermediate/subscribe/{user-id}").To(s.getSubscribe). - // Doc("get subscribe items for a user"). - // Metadata(restfulspec.KeyOpenAPITags, []string{"intermediate"}). - // Param(ws.HeaderParameter("X-API-Key", "secret key for RESTful API")). - // Param(ws.QueryParameter("user-id", "identifier of the user").DataType("string")). - // Param(ws.QueryParameter("n", "number of returned items").DataType("int")). - // Param(ws.QueryParameter("offset", "offset of the list").DataType("int")). - // Param(ws.QueryParameter("return", "return type (id/detail)").DataType("string")). - // Writes([]string{})) /* Rank recommendation */ @@ -304,6 +297,7 @@ func (s *RestServer) CreateWebService() { Writes([]data.Measurement{})) } +// ParseInt parses integers from the query parameter. func ParseInt(request *restful.Request, name string, fallback int) (value int, err error) { valueString := request.QueryParameter(name) value, err = strconv.Atoi(valueString) @@ -326,7 +320,7 @@ func (s *RestServer) getList(prefix string, name string, request *restful.Reques return } // Get the popular list - items, err := s.CacheStore.GetScores(prefix, name, begin, end) + items, err := s.CacheClient.GetScores(prefix, name, begin, end) if err != nil { InternalServerError(response, err) return @@ -382,7 +376,7 @@ func (s *RestServer) getTypedFeedbackByItem(request *restful.Request, response * } feedbackType := request.PathParameter("feedback-type") itemId := request.PathParameter("item-id") - feedback, err := s.DataStore.GetItemFeedback(itemId, feedbackType) + feedback, err := s.DataClient.GetItemFeedback(itemId, feedbackType) if err != nil { InternalServerError(response, err) return @@ -397,7 +391,7 @@ func (s *RestServer) getFeedbackByItem(request *restful.Request, response *restf return } itemId := request.PathParameter("item-id") - feedback, err := s.DataStore.GetItemFeedback(itemId) + feedback, err := s.DataClient.GetItemFeedback(itemId) if err != nil { InternalServerError(response, err) return @@ -438,17 +432,6 @@ func (s *RestServer) getCollaborative(request *restful.Request, response *restfu s.getList(cache.CollaborativeItems, userId, request, response) } -// InsertFeedbackTwice insert feedback both to database and cache. -func (s *RestServer) InsertFeedbackTwice(feedback data.Feedback, insertUser, insertItem bool) error { - // 1. insert feedback to database - err := s.DataStore.InsertFeedback(feedback, insertUser, insertItem) - if err != nil { - return err - } - // 2. insert feedback to cache - return s.CacheStore.AppendList(cache.IgnoreItems, feedback.UserId, feedback.ItemId) -} - // Recommend items to users. // 1. If there are recommendations in cache, return cached recommendations. // 2. If there are historical interactions of the users, return similar items. @@ -463,7 +446,7 @@ func (s *RestServer) Recommend(userId string, n int) ([]string, error) { errChan := make(chan error, 1) go func() { var collaborativeFilteringItems []cache.ScoredItem - collaborativeFilteringItems, err = s.CacheStore.GetScores(cache.CollaborativeItems, userId, 0, s.GorseConfig.Database.CacheSize) + collaborativeFilteringItems, err = s.CacheClient.GetScores(cache.CollaborativeItems, userId, 0, s.GorseConfig.Database.CacheSize) if err != nil { itemsChan <- nil errChan <- err @@ -478,7 +461,7 @@ func (s *RestServer) Recommend(userId string, n int) ([]string, error) { // 0. load ignore items loadCachedReadStart := time.Now() - ignoreItems, err := s.CacheStore.GetList(cache.IgnoreItems, userId) + ignoreItems, err := s.CacheClient.GetList(cache.IgnoreItems, userId) excludeSet := set.NewStringSet() for _, item := range ignoreItems { excludeSet.Add(item) @@ -504,7 +487,7 @@ func (s *RestServer) Recommend(userId string, n int) ([]string, error) { if len(results) < n { // load historical feedback loadArchReadStart := time.Now() - userFeedback, err := s.DataStore.GetUserFeedback(userId) + userFeedback, err := s.DataClient.GetUserFeedback(userId) if err != nil { return nil, err } @@ -517,7 +500,7 @@ func (s *RestServer) Recommend(userId string, n int) ([]string, error) { candidates := make(map[string]float32) for _, feedback := range userFeedback { // load similar items - similarItems, err := s.CacheStore.GetScores(cache.SimilarItems, feedback.ItemId, 0, s.GorseConfig.Database.CacheSize) + similarItems, err := s.CacheClient.GetScores(cache.SimilarItems, feedback.ItemId, 0, s.GorseConfig.Database.CacheSize) if err != nil { return nil, err } @@ -547,9 +530,9 @@ func (s *RestServer) Recommend(userId string, n int) ([]string, error) { var fallbacks []cache.ScoredItem switch s.GorseConfig.Recommend.FallbackRecommend { case "latest": - fallbacks, err = s.CacheStore.GetScores(cache.LatestItems, "", 0, s.GorseConfig.Database.CacheSize) + fallbacks, err = s.CacheClient.GetScores(cache.LatestItems, "", 0, s.GorseConfig.Database.CacheSize) case "popular": - fallbacks, err = s.CacheStore.GetScores(cache.PopularItems, "", 0, s.GorseConfig.Database.CacheSize) + fallbacks, err = s.CacheClient.GetScores(cache.PopularItems, "", 0, s.GorseConfig.Database.CacheSize) default: return nil, fmt.Errorf("unknown fallback recommendation method `%s`", s.GorseConfig.Recommend.FallbackRecommend) } @@ -599,14 +582,22 @@ func (s *RestServer) getRecommend(request *restful.Request, response *restful.Re // write back if writeBackFeedback != "" { for _, itemId := range results { - err = s.InsertFeedbackTwice(data.Feedback{ + // insert to data store + feedback := data.Feedback{ FeedbackKey: data.FeedbackKey{ UserId: userId, ItemId: itemId, FeedbackType: writeBackFeedback, }, Timestamp: time.Now(), - }, false, false) + } + err = s.DataClient.InsertFeedback(feedback, false, false) + if err != nil { + InternalServerError(response, err) + return + } + // insert to cache store + err = s.InsertFeedbackToCache([]data.Feedback{feedback}) if err != nil { InternalServerError(response, err) return @@ -617,6 +608,7 @@ func (s *RestServer) getRecommend(request *restful.Request, response *restful.Re Ok(response, results) } +// Success is the returned data structure for data insert operations. type Success struct { RowAffected int } @@ -632,7 +624,7 @@ func (s *RestServer) insertUser(request *restful.Request, response *restful.Resp BadRequest(response, err) return } - if err := s.DataStore.InsertUser(temp); err != nil { + if err := s.DataClient.InsertUser(temp); err != nil { InternalServerError(response, err) return } @@ -647,7 +639,7 @@ func (s *RestServer) getUser(request *restful.Request, response *restful.Respons // get user id userId := request.PathParameter("user-id") // get user - user, err := s.DataStore.GetUser(userId) + user, err := s.DataClient.GetUser(userId) if err != nil { if err == data.ErrUserNotExist { PageNotFound(response, err) @@ -673,7 +665,7 @@ func (s *RestServer) insertUsers(request *restful.Request, response *restful.Res var count int // range temp and achieve user for _, user := range *temp { - if err := s.DataStore.InsertUser(user); err != nil { + if err := s.DataClient.InsertUser(user); err != nil { InternalServerError(response, err) return } @@ -699,7 +691,7 @@ func (s *RestServer) getUsers(request *restful.Request, response *restful.Respon return } // get all users - cursor, users, err := s.DataStore.GetUsers(cursor, n) + cursor, users, err := s.DataClient.GetUsers(cursor, n) if err != nil { InternalServerError(response, err) return @@ -715,7 +707,7 @@ func (s *RestServer) deleteUser(request *restful.Request, response *restful.Resp } // get user-id and put into temp userId := request.PathParameter("user-id") - if err := s.DataStore.DeleteUser(userId); err != nil { + if err := s.DataClient.DeleteUser(userId); err != nil { InternalServerError(response, err) return } @@ -730,7 +722,7 @@ func (s *RestServer) getTypedFeedbackByUser(request *restful.Request, response * } feedbackType := request.PathParameter("feedback-type") userId := request.PathParameter("user-id") - feedback, err := s.DataStore.GetUserFeedback(userId, feedbackType) + feedback, err := s.DataClient.GetUserFeedback(userId, feedbackType) if err != nil { InternalServerError(response, err) return @@ -745,7 +737,7 @@ func (s *RestServer) getFeedbackByUser(request *restful.Request, response *restf return } userId := request.PathParameter("user-id") - feedback, err := s.DataStore.GetUserFeedback(userId) + feedback, err := s.DataClient.GetUserFeedback(userId) if err != nil { InternalServerError(response, err) return @@ -761,7 +753,6 @@ type Item struct { Comment string } -// putItems puts items into the database. func (s *RestServer) insertItems(request *restful.Request, response *restful.Response) { // Authorize if !s.auth(request, response) { @@ -782,7 +773,7 @@ func (s *RestServer) insertItems(request *restful.Request, response *restful.Res BadRequest(response, err) return } - err = s.DataStore.InsertItem(data.Item{ItemId: item.ItemId, Timestamp: timestamp, Labels: item.Labels, Comment: item.Comment}) + err = s.DataClient.InsertItem(data.Item{ItemId: item.ItemId, Timestamp: timestamp, Labels: item.Labels, Comment: item.Comment}) count++ if err != nil { InternalServerError(response, err) @@ -809,13 +800,14 @@ func (s *RestServer) insertItem(request *restful.Request, response *restful.Resp BadRequest(response, err) return } - if err = s.DataStore.InsertItem(data.Item{ItemId: item.ItemId, Timestamp: timestamp, Labels: item.Labels, Comment: item.Comment}); err != nil { + if err = s.DataClient.InsertItem(data.Item{ItemId: item.ItemId, Timestamp: timestamp, Labels: item.Labels, Comment: item.Comment}); err != nil { InternalServerError(response, err) return } Ok(response, Success{RowAffected: 1}) } +// ItemIterator is the iterator for items. type ItemIterator struct { Cursor string Items []data.Item @@ -832,7 +824,7 @@ func (s *RestServer) getItems(request *restful.Request, response *restful.Respon BadRequest(response, err) return } - cursor, items, err := s.DataStore.GetItems(cursor, n, nil) + cursor, items, err := s.DataClient.GetItems(cursor, n, nil) if err != nil { InternalServerError(response, err) return @@ -848,7 +840,7 @@ func (s *RestServer) getItem(request *restful.Request, response *restful.Respons // Get item id itemId := request.PathParameter("item-id") // Get item - item, err := s.DataStore.GetItem(itemId) + item, err := s.DataClient.GetItem(itemId) if err != nil { if err == data.ErrItemNotExist { PageNotFound(response, err) @@ -866,7 +858,7 @@ func (s *RestServer) deleteItem(request *restful.Request, response *restful.Resp return } itemId := request.PathParameter("item-id") - if err := s.DataStore.DeleteItem(itemId); err != nil { + if err := s.DataClient.DeleteItem(itemId); err != nil { InternalServerError(response, err) return } @@ -880,7 +872,6 @@ type Feedback struct { Comment string } -// putFeedback puts new ratings into the database. func (s *RestServer) insertFeedback(request *restful.Request, response *restful.Response) { // authorize if !s.auth(request, response) { @@ -906,9 +897,9 @@ func (s *RestServer) insertFeedback(request *restful.Request, response *restful. return } } - // Insert feedback + // insert feedback to data store for _, v := range feedback { - err = s.InsertFeedbackTwice(v, + err = s.DataClient.InsertFeedback(v, s.GorseConfig.Database.AutoInsertUser, s.GorseConfig.Database.AutoInsertItem) if err != nil { @@ -916,8 +907,14 @@ func (s *RestServer) insertFeedback(request *restful.Request, response *restful. return } } + // insert feedback to cache store + if err = s.InsertFeedbackToCache(feedback); err != nil { + InternalServerError(response, err) + return + } + for _, userId := range users.List() { - err = s.CacheStore.SetString(cache.LastActiveTime, userId, base.Now()) + err = s.CacheClient.SetString(cache.LastActiveTime, userId, base.Now()) if err != nil { InternalServerError(response, err) return @@ -926,12 +923,12 @@ func (s *RestServer) insertFeedback(request *restful.Request, response *restful. Ok(response, Success{RowAffected: len(feedback)}) } +// FeedbackIterator is the iterator for feedback. type FeedbackIterator struct { Cursor string Feedback []data.Feedback } -// Get feedback func (s *RestServer) getFeedback(request *restful.Request, response *restful.Response) { // Authorize if !s.auth(request, response) { @@ -944,7 +941,7 @@ func (s *RestServer) getFeedback(request *restful.Request, response *restful.Res BadRequest(response, err) return } - cursor, feedback, err := s.DataStore.GetFeedback(cursor, n, nil) + cursor, feedback, err := s.DataClient.GetFeedback(cursor, n, nil) if err != nil { InternalServerError(response, err) return @@ -965,7 +962,7 @@ func (s *RestServer) getTypedFeedback(request *restful.Request, response *restfu BadRequest(response, err) return } - cursor, feedback, err := s.DataStore.GetFeedback(cursor, n, nil, feedbackType) + cursor, feedback, err := s.DataClient.GetFeedback(cursor, n, nil, feedbackType) if err != nil { InternalServerError(response, err) return @@ -981,7 +978,7 @@ func (s *RestServer) getUserItemFeedback(request *restful.Request, response *res // Parse parameters userId := request.PathParameter("user-id") itemId := request.PathParameter("item-id") - if feedback, err := s.DataStore.GetUserItemFeedback(userId, itemId); err != nil { + if feedback, err := s.DataClient.GetUserItemFeedback(userId, itemId); err != nil { InternalServerError(response, err) } else { Ok(response, feedback) @@ -996,7 +993,7 @@ func (s *RestServer) deleteUserItemFeedback(request *restful.Request, response * // Parse parameters userId := request.PathParameter("user-id") itemId := request.PathParameter("item-id") - if deleteCount, err := s.DataStore.DeleteUserItemFeedback(userId, itemId); err != nil { + if deleteCount, err := s.DataClient.DeleteUserItemFeedback(userId, itemId); err != nil { InternalServerError(response, err) } else { Ok(response, Success{RowAffected: deleteCount}) @@ -1012,7 +1009,7 @@ func (s *RestServer) getTypedUserItemFeedback(request *restful.Request, response feedbackType := request.PathParameter("feedback-type") userId := request.PathParameter("user-id") itemId := request.PathParameter("item-id") - if feedback, err := s.DataStore.GetUserItemFeedback(userId, itemId, feedbackType); err != nil { + if feedback, err := s.DataClient.GetUserItemFeedback(userId, itemId, feedbackType); err != nil { InternalServerError(response, err) } else if len(feedbackType) == 0 { Text(response, "{}") @@ -1030,7 +1027,7 @@ func (s *RestServer) deleteTypedUserItemFeedback(request *restful.Request, respo feedbackType := request.PathParameter("feedback-type") userId := request.PathParameter("user-id") itemId := request.PathParameter("item-id") - if deleteCount, err := s.DataStore.DeleteUserItemFeedback(userId, itemId, feedbackType); err != nil { + if deleteCount, err := s.DataClient.DeleteUserItemFeedback(userId, itemId, feedbackType); err != nil { InternalServerError(response, err) } else { Ok(response, Success{deleteCount}) @@ -1049,7 +1046,7 @@ func (s *RestServer) getMeasurements(request *restful.Request, response *restful BadRequest(response, err) return } - measurements, err := s.DataStore.GetMeasurements(name, n) + measurements, err := s.DataClient.GetMeasurements(name, n) if err != nil { InternalServerError(response, err) return @@ -1057,6 +1054,7 @@ func (s *RestServer) getMeasurements(request *restful.Request, response *restful Ok(response, measurements) } +// BadRequest returns a bad request error. func BadRequest(response *restful.Response, err error) { response.Header().Set("Access-Control-Allow-Origin", "*") base.Logger().Error("bad request", zap.Error(err)) @@ -1065,6 +1063,7 @@ func BadRequest(response *restful.Response, err error) { } } +// InternalServerError returns a internal server error. func InternalServerError(response *restful.Response, err error) { response.Header().Set("Access-Control-Allow-Origin", "*") base.Logger().Error("internal server error", zap.Error(err)) @@ -1073,6 +1072,7 @@ func InternalServerError(response *restful.Response, err error) { } } +// PageNotFound returns a not found error. func PageNotFound(response *restful.Response, err error) { response.Header().Set("Access-Control-Allow-Origin", "*") if err := response.WriteError(http.StatusNotFound, err); err != nil { @@ -1088,6 +1088,7 @@ func Ok(response *restful.Response, content interface{}) { } } +// Text returns a plain text. func Text(response *restful.Response, content string) { response.Header().Set("Access-Control-Allow-Origin", "*") if _, err := response.Write([]byte(content)); err != nil { @@ -1111,3 +1112,14 @@ func (s *RestServer) auth(request *restful.Request, response *restful.Response) } return false } + +// InsertFeedbackToCache inserts feedback to cache. +func (s *RestServer) InsertFeedbackToCache(feedback []data.Feedback) error { + for _, v := range feedback { + err := s.CacheClient.AppendList(cache.IgnoreItems, v.UserId, v.ItemId) + if err != nil { + return err + } + } + return nil +} diff --git a/server/rest_test.go b/server/rest_test.go index ee4620233..6d7733c60 100644 --- a/server/rest_test.go +++ b/server/rest_test.go @@ -33,10 +33,8 @@ const apiKey = "test_api_key" type mockServer struct { dataStoreServer *miniredis.Miniredis cacheStoreServer *miniredis.Miniredis - dataStoreClient data.Database - cacheStoreClient cache.Database handler *restful.Container - server *RestServer + RestServer } func newMockServer(t *testing.T) *mockServer { @@ -48,29 +46,28 @@ func newMockServer(t *testing.T) *mockServer { s.cacheStoreServer, err = miniredis.Run() assert.Nil(t, err) // open database - s.dataStoreClient, err = data.Open("redis://" + s.dataStoreServer.Addr()) + s.DataClient, err = data.Open("redis://" + s.dataStoreServer.Addr()) assert.Nil(t, err) - s.cacheStoreClient, err = cache.Open("redis://" + s.cacheStoreServer.Addr()) + s.CacheClient, err = cache.Open("redis://" + s.cacheStoreServer.Addr()) assert.Nil(t, err) - // create server - server := &RestServer{ - DataStore: s.dataStoreClient, - CacheStore: s.cacheStoreClient, - GorseConfig: (*config.Config)(nil).LoadDefaultIfNil(), - } - server.GorseConfig.Server.APIKey = apiKey - server.WebService = new(restful.WebService) - server.CreateWebService() + // configuration + s.GorseConfig = (*config.Config)(nil).LoadDefaultIfNil() + s.GorseConfig.Server.APIKey = apiKey + s.WebService = new(restful.WebService) + s.CreateWebService() // create handler s.handler = restful.NewContainer() - s.handler.Add(server.WebService) - s.server = server + s.handler.Add(s.WebService) return s } func (s *mockServer) Close(t *testing.T) { - err := s.dataStoreClient.Close() + err := s.DataClient.Close() + assert.Nil(t, err) + err = s.CacheClient.Close() assert.Nil(t, err) + s.dataStoreServer.Close() + s.cacheStoreServer.Close() } func marshal(t *testing.T, v interface{}) string { @@ -355,7 +352,7 @@ func TestServer_List(t *testing.T) { {"3", 97}, {"4", 96}, } - err := s.cacheStoreClient.SetScores(operator.Prefix, operator.Label, items) + err := s.CacheClient.SetScores(operator.Prefix, operator.Label, items) assert.Nil(t, err) apitest.New(). Handler(s.handler). @@ -472,7 +469,7 @@ func TestServer_Measurement(t *testing.T) { {"Test_Recall", time.Date(2000, 1, 1, 1, 1, 1, 0, time.UTC), 1, "f"}, } for _, measurement := range measurements { - err := s.dataStoreClient.InsertMeasurement(measurement) + err := s.DataClient.InsertMeasurement(measurement) assert.Nil(t, err) } apitest.New(). @@ -494,7 +491,7 @@ func TestServer_GetRecommends(t *testing.T) { s := newMockServer(t) defer s.Close(t) // insert recommendation - err := s.cacheStoreClient.SetScores(cache.CollaborativeItems, "0", + err := s.CacheClient.SetScores(cache.CollaborativeItems, "0", []cache.ScoredItem{ {"1", 99}, {"2", 98}, @@ -511,10 +508,15 @@ func TestServer_GetRecommends(t *testing.T) { {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "2"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "4"}}, } - for _, v := range feedback { - err = s.server.InsertFeedbackTwice(v, true, true) - assert.Nil(t, err) - } + apitest.New(). + Handler(s.handler). + Post("/api/feedback"). + Header("X-API-Key", apiKey). + JSON(feedback). + Expect(t). + Status(http.StatusOK). + Body(`{"RowAffected": 2}`). + End() apitest.New(). Handler(s.handler). Get("/api/recommend/0"). @@ -555,7 +557,7 @@ func TestServer_GetRecommends_Fallback_Similar(t *testing.T) { s := newMockServer(t) defer s.Close(t) // insert recommendation - err := s.cacheStoreClient.SetScores(cache.CollaborativeItems, "0", + err := s.CacheClient.SetScores(cache.CollaborativeItems, "0", []cache.ScoredItem{{"1", 99}, {"2", 98}, {"3", 97}, {"4", 96}}) assert.Nil(t, err) // insert feedback @@ -565,30 +567,35 @@ func TestServer_GetRecommends_Fallback_Similar(t *testing.T) { {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "3"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "4"}}, } - for _, v := range feedback { - err = s.server.InsertFeedbackTwice(v, true, true) - assert.Nil(t, err) - } + apitest.New(). + Handler(s.handler). + Post("/api/feedback"). + Header("X-API-Key", apiKey). + JSON(feedback). + Expect(t). + Status(http.StatusOK). + Body(`{"RowAffected": 4}`). + End() // insert similar items - err = s.cacheStoreClient.SetScores(cache.SimilarItems, "1", []cache.ScoredItem{ + err = s.CacheClient.SetScores(cache.SimilarItems, "1", []cache.ScoredItem{ {"2", 100000}, {"9", 1}, }) assert.Nil(t, err) - err = s.cacheStoreClient.SetScores(cache.SimilarItems, "2", []cache.ScoredItem{ + err = s.CacheClient.SetScores(cache.SimilarItems, "2", []cache.ScoredItem{ {"3", 100000}, {"8", 1}, {"9", 1}, }) assert.Nil(t, err) - err = s.cacheStoreClient.SetScores(cache.SimilarItems, "3", []cache.ScoredItem{ + err = s.CacheClient.SetScores(cache.SimilarItems, "3", []cache.ScoredItem{ {"4", 100000}, {"7", 1}, {"8", 1}, {"9", 1}, }) assert.Nil(t, err) - err = s.cacheStoreClient.SetScores(cache.SimilarItems, "4", []cache.ScoredItem{ + err = s.CacheClient.SetScores(cache.SimilarItems, "4", []cache.ScoredItem{ {"1", 100000}, {"6", 1}, {"7", 1}, @@ -597,7 +604,7 @@ func TestServer_GetRecommends_Fallback_Similar(t *testing.T) { }) assert.Nil(t, err) // test fallback - s.server.GorseConfig.Recommend.FallbackRecommend = "popular" + s.GorseConfig.Recommend.FallbackRecommend = "popular" apitest.New(). Handler(s.handler). Get("/api/recommend/0"). @@ -615,19 +622,19 @@ func TestServer_GetRecommends_Fallback_NonPersonalized(t *testing.T) { s := newMockServer(t) defer s.Close(t) // insert recommendation - err := s.cacheStoreClient.SetScores(cache.CollaborativeItems, "0", + err := s.CacheClient.SetScores(cache.CollaborativeItems, "0", []cache.ScoredItem{{"1", 99}, {"2", 98}, {"3", 97}, {"4", 96}}) assert.Nil(t, err) // insert latest - err = s.cacheStoreClient.SetScores(cache.LatestItems, "", + err = s.CacheClient.SetScores(cache.LatestItems, "", []cache.ScoredItem{{"5", 95}, {"6", 94}, {"7", 93}, {"8", 92}}) assert.Nil(t, err) // insert popular - err = s.cacheStoreClient.SetScores(cache.PopularItems, "", + err = s.CacheClient.SetScores(cache.PopularItems, "", []cache.ScoredItem{{"9", 91}, {"10", 90}, {"11", 89}, {"12", 88}}) assert.Nil(t, err) // test popular fallback - s.server.GorseConfig.Recommend.FallbackRecommend = "popular" + s.GorseConfig.Recommend.FallbackRecommend = "popular" apitest.New(). Handler(s.handler). Get("/api/recommend/0"). @@ -640,7 +647,7 @@ func TestServer_GetRecommends_Fallback_NonPersonalized(t *testing.T) { Body(marshal(t, []string{"1", "2", "3", "4", "9", "10", "11", "12"})). End() // test latest fallback - s.server.GorseConfig.Recommend.FallbackRecommend = "latest" + s.GorseConfig.Recommend.FallbackRecommend = "latest" apitest.New(). Handler(s.handler). Get("/api/recommend/0"). @@ -653,7 +660,7 @@ func TestServer_GetRecommends_Fallback_NonPersonalized(t *testing.T) { Body(marshal(t, []string{"1", "2", "3", "4", "5", "6", "7", "8"})). End() // test wrong fallback - s.server.GorseConfig.Recommend.FallbackRecommend = "" + s.GorseConfig.Recommend.FallbackRecommend = "" apitest.New(). Handler(s.handler). Get("/api/recommend/0"). diff --git a/server/server.go b/server/server.go index 731c8151c..c391c33cd 100644 --- a/server/server.go +++ b/server/server.go @@ -33,37 +33,25 @@ import ( "google.golang.org/grpc" ) +// Server manages states of a server node. type Server struct { RestServer - // database connections - cacheAddress string - dataAddress string - - // master connection + cachePath string + dataPath string masterClient protocol.MasterClient - - // factorization machine - //fmModel ctr.FactorizationMachine - //RankModelMutex sync.RWMutex - //fmVersion int64 - //latestFMVersion int64 - - // config - serverName string - masterHost string - masterPort int - - // events - //syncedChan chan bool + serverName string + masterHost string + masterPort int } +// NewServer creates a server node. func NewServer(masterHost string, masterPort int, serverHost string, serverPort int) *Server { return &Server{ masterHost: masterHost, masterPort: masterPort, RestServer: RestServer{ - DataStore: &data.NoDatabase{}, - CacheStore: &cache.NoDatabase{}, + DataClient: &data.NoDatabase{}, + CacheClient: &cache.NoDatabase{}, GorseConfig: (*config.Config)(nil).LoadDefaultIfNil(), HttpHost: serverHost, HttpPort: serverPort, @@ -73,6 +61,7 @@ func NewServer(masterHost string, masterPort int, serverHost string, serverPort } } +// Serve starts a server node. func (s *Server) Serve() { rand.Seed(time.Now().UTC().UnixNano()) // open local store @@ -104,33 +93,9 @@ func (s *Server) Serve() { s.masterClient = protocol.NewMasterClient(conn) go s.Sync() - //go s.Pull() s.StartHttpServer() } -// Pull factorization machine. -//func (s *RestServer) Pull() { -// defer base.CheckPanic() -// for range s.syncedChan { -// ctx := context.Background() -// // pull factorization machine -// if s.latestFMVersion != s.fmVersion { -// base.Logger().Info("pull factorization machine") -// if mfResponse, err := s.masterClient.GetFactorizationMachine(ctx, &protocol.RequestInfo{}, grpc.MaxCallRecvMsgSize(10e8)); err != nil { -// base.Logger().Error("failed to pull factorization machine", zap.Error(err)) -// } else { -// s.fmModel, err = ctr.DecodeModel(mfResponse.Model) -// if err != nil { -// base.Logger().Error("failed to decode factorization machine", zap.Error(err)) -// } else { -// s.fmVersion = mfResponse.Version -// base.Logger().Info("synced factorization machine", zap.Int64("version", s.fmVersion)) -// } -// } -// } -// } -//} - // Sync this server to the master. func (s *Server) Sync() { defer base.CheckPanic() @@ -156,33 +121,25 @@ func (s *Server) Sync() { } // connect to data store - if s.dataAddress != s.GorseConfig.Database.DataStore { + if s.dataPath != s.GorseConfig.Database.DataStore { base.Logger().Info("connect data store", zap.String("database", s.GorseConfig.Database.DataStore)) - if s.DataStore, err = data.Open(s.GorseConfig.Database.DataStore); err != nil { + if s.DataClient, err = data.Open(s.GorseConfig.Database.DataStore); err != nil { base.Logger().Error("failed to connect data store", zap.Error(err)) goto sleep } - s.dataAddress = s.GorseConfig.Database.DataStore + s.dataPath = s.GorseConfig.Database.DataStore } // connect to cache store - if s.cacheAddress != s.GorseConfig.Database.CacheStore { + if s.cachePath != s.GorseConfig.Database.CacheStore { base.Logger().Info("connect cache store", zap.String("database", s.GorseConfig.Database.CacheStore)) - if s.CacheStore, err = cache.Open(s.GorseConfig.Database.CacheStore); err != nil { + if s.CacheClient, err = cache.Open(s.GorseConfig.Database.CacheStore); err != nil { base.Logger().Error("failed to connect cache store", zap.Error(err)) goto sleep } - s.cacheAddress = s.GorseConfig.Database.CacheStore + s.cachePath = s.GorseConfig.Database.CacheStore } - // check FM version - //s.latestFMVersion = meta.FmVersion - //if s.latestFMVersion != s.fmVersion { - // base.Logger().Info("new factorization machine model found", - // zap.Int64("old_version", s.fmVersion), - // zap.Int64("new_version", s.latestFMVersion)) - // s.syncedChan <- true - //} sleep: time.Sleep(time.Duration(s.GorseConfig.Master.MetaTimeout) * time.Second) } diff --git a/server/swagger.go b/server/swagger.go index 12ee28a89..67138dcba 100644 --- a/server/swagger.go +++ b/server/swagger.go @@ -11,6 +11,7 @@ // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. + package server import ( diff --git a/storage/cache/database.go b/storage/cache/database.go index 3a16e74b5..d370d2f4c 100644 --- a/storage/cache/database.go +++ b/storage/cache/database.go @@ -19,6 +19,7 @@ import ( "github.com/go-redis/redis/v8" "github.com/pkg/errors" "strings" + "time" ) const ( @@ -30,14 +31,12 @@ const ( CollaborativeItems = "collaborative_items" SubscribeItems = "subscribe_items" - GlobalMeta = "global_meta" - CollectPopularTime = "last_update_popular_time" - CollectLatestTime = "last_update_latest_time" - CollectSimilarTime = "last_update_similar_time" - FitMatrixFactorizationTime = "last_fit_match_model_time" - FitFactorizationMachineTime = "last_fit_rank_model_time" - MatrixFactorizationVersion = "latest_match_model_version" - FactorizationMachineVersion = "latest_rank_model_version" + GlobalMeta = "global_meta" + CollectPopularTime = "last_update_popular_time" + CollectLatestTime = "last_update_latest_time" + CollectSimilarTime = "last_update_similar_time" + FitMatrixFactorizationTime = "last_fit_match_model_time" + MatrixFactorizationVersion = "latest_match_model_version" LastActiveTime = "last_active_time" LastUpdateRecommendTime = "last_update_recommend_time" @@ -87,6 +86,8 @@ type Database interface { GetList(prefix, name string) ([]string, error) GetString(prefix, name string) (string, error) SetString(prefix, name string, val string) error + GetTime(prefix, name string) (time.Time, error) + SetTime(prefix, name string, val time.Time) error GetInt(prefix, name string) (int, error) SetInt(prefix, name string, val int) error } diff --git a/storage/cache/database_test.go b/storage/cache/database_test.go index ef4468a02..7735d03a1 100644 --- a/storage/cache/database_test.go +++ b/storage/cache/database_test.go @@ -15,37 +15,39 @@ package cache import ( "testing" + "time" "github.com/stretchr/testify/assert" ) func testMeta(t *testing.T, db Database) { // Set meta string - if err := db.SetString("meta", "1", "2"); err != nil { - t.Fatal(err) - } + err := db.SetString("meta", "1", "2") + assert.Nil(t, err) // Get meta string value, err := db.GetString("meta", "1") - if err != nil { - t.Fatal(err) - } + assert.Nil(t, err) assert.Equal(t, "2", value) // Get meta not existed value, err = db.GetString("meta", "NULL") - if err == nil { - t.Fatal(err) - } + assert.ErrorIs(t, err, ErrObjectNotExist) assert.Equal(t, "", value) // Set meta int - if err = db.SetInt("meta", "1", 2); err != nil { - t.Fatal(err) - } + err = db.SetInt("meta", "1", 2) + assert.Nil(t, err) // Get meta int - if value, err := db.GetInt("meta", "1"); err != nil { - t.Fatal(err) - } else { - assert.Equal(t, 2, value) - } + valInt, err := db.GetInt("meta", "1") + assert.Nil(t, err) + assert.Equal(t, 2, valInt) + // set meta time + err = db.SetTime("meta", "1", time.Date(1996, 4, 8, 0, 0, 0, 0, time.UTC)) + assert.Nil(t, err) + // get meta time + valTime, err := db.GetTime("meta", "1") + assert.Nil(t, err) + assert.Equal(t, 1996, valTime.Year()) + assert.Equal(t, time.Month(4), valTime.Month()) + assert.Equal(t, 8, valTime.Day()) } func testScores(t *testing.T, db Database) { diff --git a/storage/cache/no_database.go b/storage/cache/no_database.go index 2edb3857c..6b55a5032 100644 --- a/storage/cache/no_database.go +++ b/storage/cache/no_database.go @@ -14,6 +14,8 @@ package cache +import "time" + // NoDatabase means no database used for cache. type NoDatabase struct{} @@ -66,3 +68,13 @@ func (NoDatabase) GetInt(prefix, name string) (int, error) { func (NoDatabase) SetInt(prefix, name string, val int) error { return ErrNoDatabase } + +// GetTime method of NoDatabase returns ErrNoDatabase. +func (NoDatabase) GetTime(prefix, name string) (time.Time, error) { + return time.Time{}, ErrNoDatabase +} + +// SetTime method of NoDatabase returns ErrNoDatabase. +func (NoDatabase) SetTime(prefix, name string, val time.Time) error { + return ErrNoDatabase +} diff --git a/storage/cache/redis.go b/storage/cache/redis.go index 2492c9115..8c67ac66e 100644 --- a/storage/cache/redis.go +++ b/storage/cache/redis.go @@ -17,7 +17,9 @@ package cache import ( "context" "encoding/json" + "github.com/araddon/dateparse" "strconv" + "time" "github.com/go-redis/redis/v8" ) @@ -28,15 +30,15 @@ type Redis struct { } // Close redis connection. -func (redis *Redis) Close() error { - return redis.client.Close() +func (r *Redis) Close() error { + return r.client.Close() } -// SetScores save a list of scored items to redis. -func (redis *Redis) SetScores(prefix, name string, items []ScoredItem) error { +// SetScores save a list of scored items to Redis. +func (r *Redis) SetScores(prefix, name string, items []ScoredItem) error { var ctx = context.Background() key := prefix + "/" + name - err := redis.client.Del(ctx, key).Err() + err := r.client.Del(ctx, key).Err() if err != nil { return err } @@ -45,7 +47,7 @@ func (redis *Redis) SetScores(prefix, name string, items []ScoredItem) error { if err != nil { return err } - err = redis.client.RPush(ctx, key, data).Err() + err = r.client.RPush(ctx, key, data).Err() if err != nil { return err } @@ -53,12 +55,12 @@ func (redis *Redis) SetScores(prefix, name string, items []ScoredItem) error { return nil } -// GetScores returns a list of scored items from redis. -func (redis *Redis) GetScores(prefix, name string, begin, end int) ([]ScoredItem, error) { +// GetScores returns a list of scored items from Redis. +func (r *Redis) GetScores(prefix, name string, begin, end int) ([]ScoredItem, error) { var ctx = context.Background() key := prefix + "/" + name res := make([]ScoredItem, 0) - data, err := redis.client.LRange(ctx, key, int64(begin), int64(end)).Result() + data, err := r.client.LRange(ctx, key, int64(begin), int64(end)).Result() if err != nil { return nil, err } @@ -73,19 +75,19 @@ func (redis *Redis) GetScores(prefix, name string, begin, end int) ([]ScoredItem return res, err } -// ClearList clears a list of items in redis. -func (redis *Redis) ClearList(prefix, name string) error { +// ClearList clears a list of items in Redis. +func (r *Redis) ClearList(prefix, name string) error { var ctx = context.Background() key := prefix + "/" + name - return redis.client.Del(ctx, key).Err() + return r.client.Del(ctx, key).Err() } -// AppendList appends a list of scored items to redis. -func (redis *Redis) AppendList(prefix, name string, items ...string) error { +// AppendList appends a list of scored items to Redis. +func (r *Redis) AppendList(prefix, name string, items ...string) error { var ctx = context.Background() key := prefix + "/" + name for _, item := range items { - err := redis.client.RPush(ctx, key, item).Err() + err := r.client.RPush(ctx, key, item).Err() if err != nil { return err } @@ -93,12 +95,12 @@ func (redis *Redis) AppendList(prefix, name string, items ...string) error { return nil } -// GetList returns a list of scored items from redis. -func (redis *Redis) GetList(prefix, name string) ([]string, error) { +// GetList returns a list of scored items from Redis. +func (r *Redis) GetList(prefix, name string) ([]string, error) { var ctx = context.Background() key := prefix + "/" + name res := make([]string, 0) - data, err := redis.client.LRange(ctx, key, 0, -1).Result() + data, err := r.client.LRange(ctx, key, 0, -1).Result() if err != nil { return nil, err } @@ -108,13 +110,13 @@ func (redis *Redis) GetList(prefix, name string) ([]string, error) { return res, err } -// GetString returns a string from redis. -func (redis *Redis) GetString(prefix, name string) (string, error) { +// GetString returns a string from Redis. +func (r *Redis) GetString(prefix, name string) (string, error) { var ctx = context.Background() key := prefix + "/" + name - val, err := redis.client.Get(ctx, key).Result() + val, err := r.client.Get(ctx, key).Result() if err != nil { - if err.Error() == "redis: nil" { + if err == redis.Nil { return "", ErrObjectNotExist } return "", err @@ -122,30 +124,48 @@ func (redis *Redis) GetString(prefix, name string) (string, error) { return val, err } -// SetString saves a string to redis. -func (redis *Redis) SetString(prefix, name string, val string) error { +// SetString saves a string to Redis. +func (r *Redis) SetString(prefix, name string, val string) error { var ctx = context.Background() key := prefix + "/" + name - if err := redis.client.Set(ctx, key, val, 0).Err(); err != nil { + if err := r.client.Set(ctx, key, val, 0).Err(); err != nil { return err } return nil } -// GetInt returns a integer from redis. -func (redis *Redis) GetInt(prefix, name string) (int, error) { - val, err := redis.GetString(prefix, name) +// GetInt returns a integer from Redis. +func (r *Redis) GetInt(prefix, name string) (int, error) { + val, err := r.GetString(prefix, name) if err != nil { - return -1, nil + return 0, nil } buf, err := strconv.Atoi(val) if err != nil { - return -1, err + return 0, err } return buf, err } -// SetInt saves a integer from redis. -func (redis *Redis) SetInt(prefix, name string, val int) error { - return redis.SetString(prefix, name, strconv.Itoa(val)) +// SetInt saves a integer from Redis. +func (r *Redis) SetInt(prefix, name string, val int) error { + return r.SetString(prefix, name, strconv.Itoa(val)) +} + +// GetTime returns a time from Redis. +func (r *Redis) GetTime(prefix, name string) (time.Time, error) { + val, err := r.GetString(prefix, name) + if err != nil { + return time.Time{}, nil + } + tm, err := dateparse.ParseAny(val) + if err != nil { + return time.Time{}, nil + } + return tm, nil +} + +// SetTime saves a time from Redis. +func (r *Redis) SetTime(prefix, name string, val time.Time) error { + return r.SetString(prefix, name, val.String()) } diff --git a/worker/local_cache.go b/worker/local_cache.go index b1138ec43..6dbababb6 100644 --- a/worker/local_cache.go +++ b/worker/local_cache.go @@ -20,11 +20,13 @@ import ( "path/filepath" ) +// LocalCache for the worker node. type LocalCache struct { path string WorkerName string } +// LoadLocalCache loads cache from a local file. func LoadLocalCache(path string) (*LocalCache, error) { state := &LocalCache{path: path} // check if file exists @@ -47,6 +49,7 @@ func LoadLocalCache(path string) (*LocalCache, error) { return state, nil } +// WriteLocalCache writes cache to a local file. func (c *LocalCache) WriteLocalCache() error { // create parent folder if not exists parent := filepath.Dir(c.path) diff --git a/worker/worker.go b/worker/worker.go index 28d66e00c..34524d49f 100644 --- a/worker/worker.go +++ b/worker/worker.go @@ -20,7 +20,6 @@ import ( "encoding/gob" "encoding/json" "fmt" - "github.com/araddon/dateparse" "github.com/prometheus/client_golang/prometheus/promhttp" "github.com/scylladb/go-set" "github.com/zhenghaoz/gorse/model/ranking" @@ -39,10 +38,11 @@ import ( "google.golang.org/grpc" ) +// Worker manages states of a worker node. type Worker struct { // worker config cfg *config.Config - Jobs int + jobs int workerName string httpHost string httpPort int @@ -50,23 +50,23 @@ type Worker struct { masterPort int // database connection - cacheAddress string - cacheStore cache.Database - dataAddress string - dataStore data.Database + cachePath string + cacheClient cache.Database + dataPath string + dataClient data.Database // master connection - MasterClient protocol.MasterClient + masterClient protocol.MasterClient // user index - latestUserVersion int64 - userVersion int64 - userIndex *base.MapIndex + latestUserIndexVersion int64 + currentUserIndexVersion int64 + userIndex *base.MapIndex - // collaborative filtering model - latestPRVersion int64 - prModelVersion int64 - prModel ranking.Model + // ranking model + latestRankingModelVersion int64 + currentRankingModelVersion int64 + rankingModel ranking.Model // peers peers []string @@ -75,25 +75,26 @@ type Worker struct { // events ticker *time.Ticker syncedChan chan bool // meta synced events - pullChan chan bool // model pulled events + pulledChan chan bool // model pulled events } +// NewWorker creates a new worker node. func NewWorker(masterHost string, masterPort int, httpHost string, httpPort int, jobs int) *Worker { return &Worker{ // database - dataStore: data.NoDatabase{}, - cacheStore: cache.NoDatabase{}, + dataClient: data.NoDatabase{}, + cacheClient: cache.NoDatabase{}, // config masterHost: masterHost, masterPort: masterPort, httpHost: httpHost, httpPort: httpPort, - Jobs: jobs, + jobs: jobs, cfg: (*config.Config)(nil).LoadDefaultIfNil(), // events ticker: time.NewTicker(time.Minute), syncedChan: make(chan bool, 1024), - pullChan: make(chan bool, 1024), + pulledChan: make(chan bool, 1024), } } @@ -104,7 +105,7 @@ func (w *Worker) Sync() { for { var meta *protocol.Meta var err error - if meta, err = w.MasterClient.GetMeta(context.Background(), + if meta, err = w.masterClient.GetMeta(context.Background(), &protocol.NodeInfo{ NodeType: protocol.NodeType_WorkerNode, NodeName: w.workerName, @@ -122,40 +123,40 @@ func (w *Worker) Sync() { } // connect to data store - if w.dataAddress != w.cfg.Database.DataStore { + if w.dataPath != w.cfg.Database.DataStore { base.Logger().Info("connect data store", zap.String("database", w.cfg.Database.DataStore)) - if w.dataStore, err = data.Open(w.cfg.Database.DataStore); err != nil { + if w.dataClient, err = data.Open(w.cfg.Database.DataStore); err != nil { base.Logger().Error("failed to connect data store", zap.Error(err)) goto sleep } - w.dataAddress = w.cfg.Database.DataStore + w.dataPath = w.cfg.Database.DataStore } // connect to cache store - if w.cacheAddress != w.cfg.Database.CacheStore { + if w.cachePath != w.cfg.Database.CacheStore { base.Logger().Info("connect cache store", zap.String("database", w.cfg.Database.CacheStore)) - if w.cacheStore, err = cache.Open(w.cfg.Database.CacheStore); err != nil { + if w.cacheClient, err = cache.Open(w.cfg.Database.CacheStore); err != nil { base.Logger().Error("failed to connect cache store", zap.Error(err)) goto sleep } - w.cacheAddress = w.cfg.Database.CacheStore + w.cachePath = w.cfg.Database.CacheStore } - // check CF version - w.latestPRVersion = meta.PrVersion - if w.latestPRVersion != w.prModelVersion { - base.Logger().Info("new personal ranking model found", - zap.String("old_version", base.Hex(w.prModelVersion)), - zap.String("new_version", base.Hex(w.latestPRVersion))) + // check ranking model version + w.latestRankingModelVersion = meta.RankingModelVersion + if w.latestRankingModelVersion != w.currentRankingModelVersion { + base.Logger().Info("new ranking model found", + zap.String("old_version", base.Hex(w.currentRankingModelVersion)), + zap.String("new_version", base.Hex(w.latestRankingModelVersion))) w.syncedChan <- true } // check user index version - w.latestUserVersion = meta.UserIndexVersion - if w.latestUserVersion != w.userVersion { + w.latestUserIndexVersion = meta.UserIndexVersion + if w.latestUserIndexVersion != w.currentUserIndexVersion { base.Logger().Info("new user index found", - zap.String("old_version", base.Hex(w.userVersion)), - zap.String("new_version", base.Hex(w.latestUserVersion))) + zap.String("old_version", base.Hex(w.currentUserIndexVersion)), + zap.String("new_version", base.Hex(w.latestUserIndexVersion))) w.syncedChan <- true } @@ -166,16 +167,16 @@ func (w *Worker) Sync() { } } -// Pull user index and collaborative filtering model from master. +// Pull user index and ranking model from master. func (w *Worker) Pull() { defer base.CheckPanic() for range w.syncedChan { pulled := false // pull user index - if w.latestUserVersion != w.userVersion { + if w.latestUserIndexVersion != w.currentUserIndexVersion { base.Logger().Info("start pull user index") - if userIndexResponse, err := w.MasterClient.GetUserIndex(context.Background(), + if userIndexResponse, err := w.masterClient.GetUserIndex(context.Background(), &protocol.NodeInfo{NodeType: protocol.NodeType_WorkerNode, NodeName: w.workerName}, grpc.MaxCallRecvMsgSize(10e8)); err != nil { base.Logger().Error("failed to pull user index", zap.Error(err)) @@ -188,31 +189,31 @@ func (w *Worker) Pull() { base.Logger().Error("failed to decode user index", zap.Error(err)) } else { w.userIndex = &userIndex - w.userVersion = userIndexResponse.Version + w.currentUserIndexVersion = userIndexResponse.Version base.Logger().Info("synced user index", - zap.String("version", base.Hex(w.userVersion))) + zap.String("version", base.Hex(w.currentUserIndexVersion))) pulled = true } } } - // pull personal ranking model - if w.latestPRVersion != w.prModelVersion { - base.Logger().Info("start pull personal ranking model") - if mfResponse, err := w.MasterClient.GetPRModel(context.Background(), + // pull ranking model + if w.latestRankingModelVersion != w.currentRankingModelVersion { + base.Logger().Info("start pull ranking model") + if mfResponse, err := w.masterClient.GetRankingModel(context.Background(), &protocol.NodeInfo{ NodeType: protocol.NodeType_WorkerNode, NodeName: w.workerName, }, grpc.MaxCallRecvMsgSize(10e8)); err != nil { - base.Logger().Error("failed to pull personal ranking model", zap.Error(err)) + base.Logger().Error("failed to pull ranking model", zap.Error(err)) } else { - w.prModel, err = ranking.DecodeModel(mfResponse.Name, mfResponse.Model) + w.rankingModel, err = ranking.DecodeModel(mfResponse.Name, mfResponse.Model) if err != nil { - base.Logger().Error("failed to decode personal ranking model", zap.Error(err)) + base.Logger().Error("failed to decode ranking model", zap.Error(err)) } else { - w.prModelVersion = mfResponse.Version - base.Logger().Info("synced personal ranking model", - zap.String("version", base.Hex(w.prModelVersion))) + w.currentRankingModelVersion = mfResponse.Version + base.Logger().Info("synced ranking model", + zap.String("version", base.Hex(w.currentRankingModelVersion))) pulled = true } } @@ -224,6 +225,7 @@ func (w *Worker) Pull() { } } +// ServeMetrics serves Prometheus metrics. func (w *Worker) ServeMetrics() { http.Handle("/metrics", promhttp.Handler()) err := http.ListenAndServe(fmt.Sprintf("%s:%d", w.httpHost, w.httpPort), nil) @@ -232,6 +234,7 @@ func (w *Worker) ServeMetrics() { } } +// Serve as a worker node. func (w *Worker) Serve() { rand.Seed(time.Now().UTC().UnixNano()) // open local store @@ -249,7 +252,7 @@ func (w *Worker) Serve() { } w.workerName = state.WorkerName base.Logger().Info("start worker", - zap.Int("n_jobs", w.Jobs), + zap.Int("n_jobs", w.jobs), zap.String("worker_name", w.workerName)) // connect to master @@ -257,7 +260,7 @@ func (w *Worker) Serve() { if err != nil { base.Logger().Fatal("failed to connect master", zap.Error(err)) } - w.MasterClient = protocol.NewMasterClient(conn) + w.masterClient = protocol.NewMasterClient(conn) go w.Sync() go w.Pull() @@ -276,11 +279,11 @@ func (w *Worker) Serve() { return } - // offline recommendation - if w.prModel != nil { - w.Recommend(w.prModel, workingUsers) + // recommendation + if w.rankingModel != nil { + w.Recommend(w.rankingModel, workingUsers) } else { - base.Logger().Debug("local personal ranking model doesn't exist") + base.Logger().Debug("local ranking model doesn't exist") } } } @@ -295,6 +298,13 @@ func (w *Worker) Serve() { } } +// Recommend items to users. The workflow of recommendation is: +// 1. Skip inactive users. +// 2. Load historical items. +// 3. Load positive items if KNN used. +// 4. Generate recommendation. +// 5. Save result. +// 6. Refresh cache. func (w *Worker) Recommend(m ranking.Model, users []string) { var userIndexer base.Index // load user index @@ -303,10 +313,10 @@ func (w *Worker) Recommend(m ranking.Model, users []string) { } // load item index itemIds := m.GetItemIndex().GetNames() - base.Logger().Info("personal ranking recommendation", + base.Logger().Info("ranking recommendation", zap.Int("n_working_users", len(users)), zap.Int("n_items", len(itemIds)), - zap.Int("n_jobs", w.Jobs), + zap.Int("n_jobs", w.jobs), zap.Int("cache_size", w.cfg.Database.CacheSize)) // progress tracker completed := make(chan interface{}, 1000) @@ -322,15 +332,15 @@ func (w *Worker) Recommend(m ranking.Model, users []string) { } completedCount++ case <-ticker.C: - base.Logger().Info("personal ranking recommendation", + base.Logger().Info("ranking recommendation", zap.Int("n_complete_users", completedCount), zap.Int("n_working_users", len(users))) } } }() - // collaborative filtering recommendation + // recommendation startTime := time.Now() - _ = base.Parallel(len(users), w.Jobs, func(workerId, jobId int) error { + _ = base.Parallel(len(users), w.jobs, func(workerId, jobId int) error { userId := users[jobId] // convert to user index var userIndex int @@ -341,23 +351,18 @@ func (w *Worker) Recommend(m ranking.Model, users []string) { if !w.checkRecommendCacheTimeout(userId) { return nil } - // Clear ignore items in cache. Since ignore items have been ignored - // in offline recommendation stage. - err := w.cacheStore.ClearList(cache.IgnoreItems, userId) - if err != nil { - return err - } - // remove saw items - historyItems, err := loadFeedbackItems(w.dataStore, userId) + // load historical items + historyItems, err := loadUserHistoricalItems(w.dataClient, userId) historySet := set.NewStringSet(historyItems...) if err != nil { base.Logger().Error("failed to pull user feedback", zap.String("user_id", userId), zap.Error(err)) return err } - var favoredItemIndices []int + // load positive items + var positiveItemIndices []int if _, ok := m.(*ranking.KNN); ok { - favoredItems, err := loadFeedbackItems(w.dataStore, userId, w.cfg.Database.PositiveFeedbackType...) + favoredItems, err := loadUserHistoricalItems(w.dataClient, userId, w.cfg.Database.PositiveFeedbackType...) if err != nil { base.Logger().Error("failed to pull user feedback", zap.String("user_id", userId), zap.Error(err)) @@ -366,10 +371,11 @@ func (w *Worker) Recommend(m ranking.Model, users []string) { for _, itemId := range favoredItems { itemIndex := m.GetItemIndex().ToNumber(itemId) if itemIndex != base.NotId { - favoredItemIndices = append(favoredItemIndices, itemIndex) + positiveItemIndices = append(positiveItemIndices, itemIndex) } } } + // generate recommendation recItems := base.NewTopKStringFilter(w.cfg.Database.CacheSize) for itemIndex, itemId := range itemIds { if !historySet.Has(itemId) { @@ -377,25 +383,31 @@ func (w *Worker) Recommend(m ranking.Model, users []string) { case ranking.MatrixFactorization: recItems.Push(itemId, m.(ranking.MatrixFactorization).InternalPredict(userIndex, itemIndex)) case *ranking.KNN: - recItems.Push(itemId, m.(*ranking.KNN).InternalPredict(favoredItemIndices, itemIndex)) + recItems.Push(itemId, m.(*ranking.KNN).InternalPredict(positiveItemIndices, itemIndex)) default: base.Logger().Error("unknown model type") } } } + // save result elems, scores := recItems.PopAll() - if err = w.cacheStore.SetScores(cache.CollaborativeItems, userId, cache.CreateScoredItems(elems, scores)); err != nil { - base.Logger().Error("failed to cache collaborative filtering recommendation", zap.Error(err)) + if err = w.cacheClient.SetScores(cache.CollaborativeItems, userId, cache.CreateScoredItems(elems, scores)); err != nil { + base.Logger().Error("failed to cache recommendation", zap.Error(err)) return err } - if err = w.cacheStore.SetString(cache.LastUpdateRecommendTime, userId, base.Now()); err != nil { - base.Logger().Error("failed to cache collaborative filtering recommendation time", zap.Error(err)) + if err = w.cacheClient.SetString(cache.LastUpdateRecommendTime, userId, base.Now()); err != nil { + base.Logger().Error("failed to cache recommendation time", zap.Error(err)) + } + // refresh cache + err = w.refreshCache(userId) + if err != nil { + return err } completed <- nil return nil }) close(completed) - base.Logger().Info("complete personal ranking recommendation", + base.Logger().Info("complete ranking recommendation", zap.String("used_time", time.Since(startTime).String())) } @@ -404,31 +416,20 @@ func (w *Worker) Recommend(m ranking.Model, users []string) { // 2. if recommend time + timeout < now, stale. func (w *Worker) checkRecommendCacheTimeout(userId string) bool { var activeTime, recommendTime time.Time + var err error // read active time - activeTimeLiteral, err := w.cacheStore.GetString(cache.LastActiveTime, userId) - if err != nil { - if err != cache.ErrObjectNotExist { - base.Logger().Error("failed to read meta", zap.Error(err)) - } - } else { - activeTime, err = dateparse.ParseAny(activeTimeLiteral) - if err != nil { - base.Logger().Error("failed to time", zap.Error(err)) - } + activeTime, err = w.cacheClient.GetTime(cache.LastActiveTime, userId) + if err != nil && err != cache.ErrObjectNotExist { + base.Logger().Error("failed to read meta", zap.Error(err)) } // read recommend time - recommendTimeLiteral, err := w.cacheStore.GetString(cache.LastUpdateRecommendTime, userId) + recommendTime, err = w.cacheClient.GetTime(cache.LastUpdateRecommendTime, userId) if err != nil { if err != cache.ErrObjectNotExist { base.Logger().Error("failed to read meta", zap.Error(err)) } else { return true } - } else { - recommendTime, err = dateparse.ParseAny(recommendTimeLiteral) - if err != nil { - base.Logger().Error("failed to time", zap.Error(err)) - } } // check time if activeTime.Unix() < recommendTime.Unix() { @@ -438,86 +439,51 @@ func (w *Worker) checkRecommendCacheTimeout(userId string) bool { return true } -// -//func (w *Worker) Subscribe(users []string) { -// base.Logger().Info("subscribe", -// zap.Bool("implicit_subscribe", w.cfg.Subscribe.ImplicitSubscribe)) -// completed := make(chan interface{}) -// go func() { -// defer base.CheckPanic() -// completedCount := 0 -// ticker := time.NewTicker(time.Second * 5) -// for { -// select { -// case _, ok := <-completed: -// if !ok { -// return -// } -// completedCount++ -// case <-ticker.C: -// base.Logger().Info("subscribe", -// zap.Int("n_complete_users", completedCount), -// zap.Int("n_working_users", len(users))) -// } -// } -// }() -// _ = base.Parallel(len(users), w.FitJobs, func(workerId, jobId int) error { -// user := users[jobId] -// // collect items -// historySet := base.NewStringSet() -// for _, feedbackType := range w.cfg.Database.PositiveFeedbackType { -// historyFeedback, err := w.dataStore.GetUserFeedback(user, &feedbackType) -// if err != nil { -// base.Logger().Error("failed to pull user feedback", -// zap.String("user_id", user), zap.Error(err)) -// return err -// } -// for _, feedback := range historyFeedback { -// historySet.Add(feedback.ItemId) -// } -// } -// // collect labels -// labelSet := make(map[string]int) -// for itemId, _ := range historySet { -// if item, err := w.dataStore.GetItem(itemId); err != nil { -// base.Logger().Error("failed to get item", zap.String("item_id", itemId), zap.Error(err)) -// } else { -// for _, label := range item.Labels { -// labelSet[label] ++ -// } -// } -// } -// base.Logger().Info("items", zap.Any("items", labelSet)) -// completed <- nil -// return nil -// }) -// close(completed) -//} - -func loadFeedbackItems(database data.Database, userId string, feedbackTypes ...string) ([]string, error) { +func loadUserHistoricalItems(database data.Database, userId string, feedbackTypes ...string) ([]string, error) { items := make([]string, 0) - if len(feedbackTypes) == 0 { - feedbacks, err := database.GetUserFeedback(userId) - if err != nil { - return nil, err - } - for _, feedback := range feedbacks { - items = append(items, feedback.ItemId) - } + feedbacks, err := database.GetUserFeedback(userId, feedbackTypes...) + if err != nil { + return nil, err + } + for _, feedback := range feedbacks { + items = append(items, feedback.ItemId) + } + return items, nil +} + +func (w *Worker) refreshCache(userId string) error { + var timeLimit *time.Time + // read recommend time + recommendTime, err := w.cacheClient.GetTime(cache.LastUpdateRecommendTime, userId) + if err == nil { + timeLimit = &recommendTime } else { - for _, tp := range feedbackTypes { - feedbacks, err := database.GetUserFeedback(userId, tp) - if err != nil { - return nil, err - } - for _, feedback := range feedbacks { - items = append(items, feedback.ItemId) - } + return err + } + // clear cache + err = w.cacheClient.ClearList(cache.IgnoreItems, userId) + if err != nil { + return err + } + // load cache + feedback, err := w.dataClient.GetUserFeedback(userId) + if err != nil { + return err + } + var items []string + for _, v := range feedback { + if v.Timestamp.Unix() > timeLimit.Unix() { + items = append(items, v.ItemId) } } - return items, nil + err = w.cacheClient.AppendList(cache.IgnoreItems, userId, items...) + if err != nil { + return err + } + return nil } +// split users between worker nodes. func split(userIndex base.Index, nodes []string, me string) ([]string, error) { // locate me pos := -1 diff --git a/worker/worker_test.go b/worker/worker_test.go index 6d0a6bed2..2d3029a68 100644 --- a/worker/worker_test.go +++ b/worker/worker_test.go @@ -12,3 +12,135 @@ // See the License for the specific language governing permissions and // limitations under the License. package worker + +import ( + "github.com/alicebob/miniredis/v2" + "github.com/stretchr/testify/assert" + "github.com/zhenghaoz/gorse/base" + "github.com/zhenghaoz/gorse/config" + "github.com/zhenghaoz/gorse/model" + "github.com/zhenghaoz/gorse/model/ranking" + "github.com/zhenghaoz/gorse/storage/cache" + "github.com/zhenghaoz/gorse/storage/data" + "strconv" + "testing" + "time" +) + +func TestSplit(t *testing.T) { + // create user index + userIndex := base.NewMapIndex() + userIndex.Add("1") + userIndex.Add("2") + userIndex.Add("3") + userIndex.Add("4") + userIndex.Add("5") + userIndex.Add("6") + userIndex.Add("7") + userIndex.Add("8") + // create nodes + nodes := []string{"a", "b", "c"} + + users, err := split(userIndex, nodes, "b") + assert.Nil(t, err) + assert.Equal(t, []string{"2", "5", "8"}, users) + + users, err = split(userIndex, nodes, "d") + assert.Error(t, err) +} + +type mockMatrixFactorizationForRecommend struct { + ranking.BaseMatrixFactorization +} + +func newMockMatrixFactorizationForRecommend(numUsers, numItems int) *mockMatrixFactorizationForRecommend { + m := new(mockMatrixFactorizationForRecommend) + m.UserIndex = base.NewMapIndex() + m.ItemIndex = base.NewMapIndex() + for i := 0; i < numUsers; i++ { + m.UserIndex.Add(strconv.Itoa(i)) + } + for i := 0; i < numItems; i++ { + m.ItemIndex.Add(strconv.Itoa(i)) + } + return m +} + +func (m *mockMatrixFactorizationForRecommend) InternalPredict(userId, itemId int) float32 { + return float32(itemId) +} + +func (m *mockMatrixFactorizationForRecommend) Clear() { + // do nothing +} + +func (m *mockMatrixFactorizationForRecommend) GetParamsGrid() model.ParamsGrid { + panic("don't call me") +} + +type mockWorker struct { + dataStoreServer *miniredis.Miniredis + cacheStoreServer *miniredis.Miniredis + Worker +} + +func newMockWorker(t *testing.T) *mockWorker { + w := new(mockWorker) + // create mock redis server + var err error + w.dataStoreServer, err = miniredis.Run() + assert.Nil(t, err) + w.cacheStoreServer, err = miniredis.Run() + assert.Nil(t, err) + // open database + w.dataClient, err = data.Open("redis://" + w.dataStoreServer.Addr()) + assert.Nil(t, err) + w.cacheClient, err = cache.Open("redis://" + w.cacheStoreServer.Addr()) + assert.Nil(t, err) + // configuration + w.cfg = (*config.Config)(nil).LoadDefaultIfNil() + w.jobs = 1 + return w +} + +func (w *mockWorker) Close(t *testing.T) { + err := w.dataClient.Close() + assert.Nil(t, err) + err = w.cacheClient.Close() + assert.Nil(t, err) + w.dataStoreServer.Close() + w.cacheStoreServer.Close() +} + +func TestRecommendMatrixFactorization(t *testing.T) { + // create mock worker + w := newMockWorker(t) + defer w.Close(t) + // insert feedbacks + now := time.Now() + err := w.dataClient.BatchInsertFeedback([]data.Feedback{ + {FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "9"}, Timestamp: now.Add(-time.Hour)}, + {FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "8"}, Timestamp: now.Add(time.Hour)}, + {FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "7"}, Timestamp: now.Add(-time.Hour)}, + {FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "6"}, Timestamp: now.Add(time.Hour)}, + {FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "5"}, Timestamp: now.Add(-time.Hour)}, + {FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "4"}, Timestamp: now.Add(time.Hour)}, + }, true, true) + assert.Nil(t, err) + // create mock model + m := newMockMatrixFactorizationForRecommend(1, 10) + w.Recommend(m, []string{"0"}) + + recommends, err := w.cacheClient.GetScores(cache.CollaborativeItems, "0", 0, -1) + assert.Nil(t, err) + assert.Equal(t, []cache.ScoredItem{ + {"3", 3}, + {"2", 2}, + {"1", 1}, + {"0", 0}, + }, recommends) + + read, err := w.cacheClient.GetList(cache.IgnoreItems, "0") + assert.Nil(t, err) + assert.Equal(t, []string{"4", "6", "8"}, read) +}