diff --git a/prometheus.go b/prometheus.go index 096f597..58c74a4 100644 --- a/prometheus.go +++ b/prometheus.go @@ -6,8 +6,10 @@ package prometheus import ( "context" + "errors" "net" "net/http" + "os" "time" "github.com/gozix/di" @@ -29,13 +31,18 @@ const ( // TagCollectorProvider is tag marks prometheus collectors. TagCollectorProvider = "prometheus.collector" + + // tagPrometheusFlagSet is tag marks bundle flag set. + tagPrometheusFlagSet = "prometheus.flag_set" + + // flagPrometheusPort is flag name. + flagPrometheusPort = "prometheus-port" ) type ( // Bundle implements the glue.Bundle interface. Bundle struct { - registry *prometheus.Registry - flagPortValue string + registry *prometheus.Registry } // Option interface. @@ -78,8 +85,12 @@ func (b *Bundle) Build(builder di.Builder) error { di.As(new(prometheus.Gatherer)), di.As(new(prometheus.Registerer)), ), - di.Provide(b.provideFlagSet, glue.AsPersistentFlags()), - di.Provide(b.providePreRunner, glue.AsPersistentPreRunner()), + di.Provide(b.provideFlagSet, glue.AsPersistentFlags(), di.Tags{{Name: tagPrometheusFlagSet}}), + di.Provide( + b.providePreRunner, + glue.AsPersistentPreRunner(), + di.Constraint(2, di.WithTags(tagPrometheusFlagSet)), + ), ) } @@ -90,11 +101,17 @@ func (b *Bundle) DependsOn() []string { } } -func (b *Bundle) provideFlagSet() *pflag.FlagSet { - var flagSet = pflag.NewFlagSet(BundleName, pflag.ExitOnError) - flagSet.StringVar(&b.flagPortValue, "prometheus-port", "", "prometheus metrics port") +func (b *Bundle) provideFlagSet() (*pflag.FlagSet, error) { + var flagSet = pflag.NewFlagSet(BundleName, pflag.ContinueOnError) + flagSet.String(flagPrometheusPort, "", "prometheus metrics port") + flagSet.ParseErrorsWhitelist.UnknownFlags = true - return flagSet + var err = flagSet.Parse(os.Args) + if errors.Is(err, pflag.ErrHelp) { + err = nil + } + + return flagSet, err } func (b *Bundle) provideRegistry(collectors []prometheus.Collector) (_ *prometheus.Registry, err error) { @@ -119,15 +136,14 @@ func (b *Bundle) provideRegistry(collectors []prometheus.Collector) (_ *promethe func (b *Bundle) providePreRunner( cfg *viper.Viper, logger *zap.Logger, + flagSet *pflag.FlagSet, registry *prometheus.Registry, -) (_ glue.PreRunner, _ func() error, err error) { +) (glue.PreRunner, func() error, error) { // use this hack for UnmarshalKey // see https://github.com/spf13/viper/issues/188 - var configPath = BundleName - var cfgPath = cfg.Sub(configPath) - if cfgPath != nil { - for _, key := range cfg.Sub(configPath).AllKeys() { - key = configPath + "." + key + if cfgPath := cfg.Sub(BundleName); cfgPath != nil { + for _, key := range cfg.Sub(BundleName).AllKeys() { + key = BundleName + "." + key cfg.Set(key, cfg.Get(key)) } } @@ -138,12 +154,12 @@ func (b *Bundle) providePreRunner( Path string `mapstructure:"path"` }{} - if err = cfg.UnmarshalKey(configPath, &conf); err != nil { + if err := cfg.UnmarshalKey(BundleName, &conf); err != nil { return nil, nil, err } - if b.flagPortValue != "" { - conf.Port = b.flagPortValue + if flag := flagSet.Lookup(flagPrometheusPort); flag != nil && flag.Value.String() != "" { + conf.Port = flag.Value.String() } if conf.Path == "" { @@ -155,22 +171,37 @@ func (b *Bundle) providePreRunner( registry, promhttp.HandlerFor(registry, promhttp.HandlerOpts{}), )) - var srv = &http.Server{ - Addr: net.JoinHostPort(conf.Host, conf.Port), - Handler: mux, - } + var ( + srv = &http.Server{ + Addr: net.JoinHostPort(conf.Host, conf.Port), + Handler: mux, + } + log = logger.With(zap.String("addr", srv.Addr)) + ) var preRunner = glue.PreRunnerFunc(func(ctx context.Context) error { + log.Info("Starting prometheus HTTP server") + go func() { - if err = srv.ListenAndServe(); err != nil && err != http.ErrServerClosed { - logger.Panic("Error occurred during serve prometheus http handler", zap.Error(err)) + var err = srv.ListenAndServe() + switch { + case errors.Is(err, http.ErrServerClosed): + log.Info("Gracefully shutting down the HTTP server") + + case err != nil: + log.Error("Error occurred during serve prometheus http server", zap.Error(err)) } }() + log.Info("Prometheus HTTP server started") + return nil }) var closer = func() error { + var timeout = 10 * time.Second + log.Info("Stopping prometheus HTTP server", zap.Duration("timeout", timeout)) + var ctx, cancelFunc = context.WithTimeout(context.Background(), 10*time.Second) defer cancelFunc()