diff --git a/README.md b/README.md index 3cf3687..6c64b1e 100644 --- a/README.md +++ b/README.md @@ -9,13 +9,18 @@ ```json { "redis": { - "host": "127.0.0.1", - "port": "6379", - "db": 0, - "max_retiries": 2, - "read_timeout": "2s", - "write_timeout": "2s", - "idle_timeout": "1m" + "default": { + "host": "127.0.0.1", + "port": "6379", + "db": 0, + "password": "somepassword", + "max_retiries": 2, + "read_timeout": "2s", + "write_timeout": "2s", + "idle_timeout": "1m" + } } } ``` +"password" field is optional and ignored if empty +"db" field is optional. Default is 0 \ No newline at end of file diff --git a/goredis.go b/goredis.go index f9083a1..8823c88 100644 --- a/goredis.go +++ b/goredis.go @@ -3,8 +3,8 @@ package goredis import ( "errors" - "net" - "time" + "fmt" + "strings" "github.com/go-redis/redis" "github.com/gozix/viper/v2" @@ -17,13 +17,6 @@ type ( // Pool is type alias of redis.Client Pool = redis.Client - - // redisConf is logger configuration struct. - redisConf struct { - IdleTimeout time.Duration `mapstructure:"idle_timeout"` - ReadTimeout time.Duration `mapstructure:"read_timeout"` - WriteTimeout time.Duration `mapstructure:"write_timeout"` - } ) // BundleName is default definition name. @@ -44,41 +37,59 @@ func (b *Bundle) Build(builder *di.Builder) error { return builder.Add(di.Def{ Name: BundleName, Build: func(ctn di.Container) (_ interface{}, err error) { - var cnf *viper.Viper - if err = ctn.Fill(viper.BundleName, &cnf); err != nil { + var cfg *viper.Viper + if err = ctn.Fill(viper.BundleName, &cfg); err != nil { return nil, errors.New("can't get config from container") } - var conf redisConf - if err = cnf.UnmarshalKey("redis", &conf); err != nil { - return nil, err - } + // use this is hack, not UnmarshalKey + // see https://github.com/spf13/viper/issues/188 + var ( + keys = cfg.Sub(configKey).AllKeys() + conf = make(Configs, len(keys)) + ) - options := &redis.Options{ - Addr: net.JoinHostPort( - cnf.GetString("redis.host"), - cnf.GetString("redis.port"), - ), - DB: cnf.GetInt("redis.db"), - MaxRetries: cnf.GetInt("redis.max_retries"), - IdleTimeout: conf.IdleTimeout, - ReadTimeout: conf.ReadTimeout, - WriteTimeout: conf.WriteTimeout, - } + for _, key := range keys { + var name = strings.Split(key, ".")[0] + if _, ok := conf[name]; ok { + continue + } - var client *redis.Client - if client = redis.NewClient(options); client == nil { - return nil, err - } + var suffix = fmt.Sprintf("%s.%s.", configKey, name) + + cfg.SetDefault(suffix+"port", "6379") + + var c = Config{ + Host: cfg.GetString(suffix + "host"), + Port: cfg.GetString(suffix + "port"), + DB: cfg.GetInt(suffix + "db"), + Password: cfg.GetString(suffix + "password"), + MaxRetries: cfg.GetInt(suffix + "max_retries"), + IdleTimeout: cfg.GetDuration(suffix + "idle_timeout"), + ReadTimeout: cfg.GetDuration(suffix + "read_timeout"), + WriteTimeout: cfg.GetDuration(suffix + "write_timeout"), + } + + // validating + if c.Host == "" { + return nil, errors.New(suffix + "host should be set") + } + + if c.DB < 0 { + return nil, errors.New(suffix + "db should be greater or equal to 0") + } + + if c.MaxRetries < 0 { + return nil, errors.New(suffix + "max_retries should be greater or equal to 0") + } - if _, err = client.Ping().Result(); err != nil { - return nil, err + conf[name] = c } - return client, nil + return NewRegistry(conf), nil }, Close: func(obj interface{}) error { - return obj.(*redis.Client).Close() + return obj.(*Registry).Close() }, }) } diff --git a/registry.go b/registry.go new file mode 100644 index 0000000..af4559c --- /dev/null +++ b/registry.go @@ -0,0 +1,115 @@ +// Package goredis provides implementation of go-redis client. +package goredis + +import ( + "errors" + "net" + "sync" + "time" + + "github.com/go-redis/redis" +) + +// DEFAULT is default connection name. +const DEFAULT = "default" + +// ConfigKey is root config key. +const configKey = "redis" + +type ( + // Config is registry configuration item. + Config struct { + Host string `json:"host"` + Port string `json:"port"` + DB int `json:"db"` + Password string `json:"password"` + MaxRetries int `json:"max_retries"` + IdleTimeout time.Duration `json:"idle_timeout"` + ReadTimeout time.Duration `json:"read_timeout"` + WriteTimeout time.Duration `json:"write_timeout"` + } + + // Configs is registry configurations. + Configs map[string]Config + + // Registry is database connection registry. + Registry struct { + mux sync.Mutex + clients map[string]*redis.Client + conf Configs + } +) + +var ( + // ErrUnknownConnection is error triggered when connection with provided name not founded. + ErrUnknownConnection = errors.New("unknown connection") +) + +// NewRegistry is registry constructor. +func NewRegistry(conf Configs) *Registry { + return &Registry{ + clients: make(map[string]*redis.Client, 1), + conf: conf, + } +} + +// Close is method for close connections. +func (r *Registry) Close() (err error) { + r.mux.Lock() + defer r.mux.Unlock() + + for key, client := range r.clients { + if errClose := client.Close(); errClose != nil { + err = errClose + } + + delete(r.clients, key) + } + + return err +} + +// Connection is default connection getter. +func (r *Registry) Connection() (*redis.Client, error) { + return r.ConnectionWithName(DEFAULT) +} + +// ConnectionWithName is connection getter by name. +func (r *Registry) ConnectionWithName(name string) (_ *redis.Client, err error) { + r.mux.Lock() + defer r.mux.Unlock() + + var client, initialized = r.clients[name] + if initialized { + return client, nil + } + + var cfg, exists = r.conf[name] + if !exists { + return nil, ErrUnknownConnection + } + + var options = &redis.Options{ + Addr: net.JoinHostPort( + cfg.Host, + cfg.Port, + ), + DB: cfg.DB, + MaxRetries: cfg.MaxRetries, + IdleTimeout: cfg.IdleTimeout, + ReadTimeout: cfg.ReadTimeout, + WriteTimeout: cfg.WriteTimeout, + } + + if client = redis.NewClient(options); client == nil { + return nil, err + } + + if _, err = client.Ping().Result(); err != nil { + return nil, err + } + + r.clients[name] = client + + return client, nil +}