Skip to content

Commit

Permalink
endpointaccessible: check if endpoint parameters changed at every sync
Browse files Browse the repository at this point in the history
If there are no changes in the endpoint parameters, skip the check.
  • Loading branch information
liouk committed Jan 25, 2024
1 parent 594a160 commit 91360bd
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 30 deletions.
65 changes: 46 additions & 19 deletions pkg/libs/endpointaccessible/endpoint_accessible_controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package endpointaccessible
import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"net/http"
"strings"
Expand All @@ -11,6 +12,7 @@ import (

apierrors "k8s.io/apimachinery/pkg/api/errors"
utilerrors "k8s.io/apimachinery/pkg/util/errors"
"k8s.io/apimachinery/pkg/util/sets"

operatorv1 "github.com/openshift/api/operator/v1"
"github.com/openshift/library-go/pkg/controller/factory"
Expand All @@ -23,6 +25,12 @@ type endpointAccessibleController struct {
endpointListFn EndpointListFunc
getTLSConfigFn EndpointTLSConfigFunc
availableConditionName string

maxCheckLatency time.Duration
lastCheckTime time.Time
lastEndpoints sets.Set[string]
lastServerName string
lastCA *x509.CertPool
}

type EndpointListFunc func() ([]string, error)
Expand All @@ -47,6 +55,8 @@ func NewEndpointAccessibleController(
endpointListFn: endpointListFn,
getTLSConfigFn: getTLSConfigFn,
availableConditionName: name + "EndpointAccessibleControllerAvailable",
maxCheckLatency: resyncInterval - 5*time.Second,
lastEndpoints: sets.New[string](),
}

return factory.New().
Expand All @@ -70,26 +80,41 @@ func humanizeError(err error) error {

func (c *endpointAccessibleController) sync(ctx context.Context, syncCtx factory.SyncContext) error {
endpoints, err := c.endpointListFn()
if err != nil {
if apierrors.IsNotFound(err) {
_, _, statusErr := v1helpers.UpdateStatus(ctx, c.operatorClient, v1helpers.UpdateConditionFn(
operatorv1.OperatorCondition{
Type: c.availableConditionName,
Status: operatorv1.ConditionFalse,
Reason: "ResourceNotFound",
Message: err.Error(),
}))

return statusErr
}
if apierrors.IsNotFound(err) {
_, _, statusErr := v1helpers.UpdateStatus(ctx, c.operatorClient, v1helpers.UpdateConditionFn(
operatorv1.OperatorCondition{
Type: c.availableConditionName,
Status: operatorv1.ConditionFalse,
Reason: "ResourceNotFound",
Message: err.Error(),
}))

return statusErr
} else if err != nil {
return err
}

newEndpoints := sets.New(endpoints...)
endpointsChanged := !c.lastEndpoints.Equal(newEndpoints)

tlsConfig, err := c.getTLSConfigFn()
if err != nil {
return err
}
tlsChanged := c.lastServerName != tlsConfig.ServerName || !tlsConfig.RootCAs.Equal(c.lastCA)

client, err := c.buildTLSClient()
isPastTimeForCheck := time.Since(c.lastCheckTime) > c.maxCheckLatency
if !endpointsChanged && !tlsChanged && !isPastTimeForCheck {
return nil
}
c.lastCheckTime = time.Now()
c.lastEndpoints = newEndpoints

client, err := c.buildTLSClient(tlsConfig)
if err != nil {
return err
}

// check all the endpoints in parallel. This matters for pods.
errCh := make(chan error, len(endpoints))
wg := sync.WaitGroup{}
Expand Down Expand Up @@ -155,20 +180,22 @@ func (c *endpointAccessibleController) sync(ctx context.Context, syncCtx factory
return utilerrors.NewAggregate(errors)
}

func (c *endpointAccessibleController) buildTLSClient() (*http.Client, error) {
func (c *endpointAccessibleController) buildTLSClient(tlsConfig *tls.Config) (*http.Client, error) {
transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: &tls.Config{
InsecureSkipVerify: true,
},
}
if c.getTLSConfigFn != nil {
tlsConfig, err := c.getTLSConfigFn()
if err != nil {
return nil, err
}

if tlsConfig != nil {
transport.TLSClientConfig = tlsConfig

// these are the fields that are set by our getTLSConfigFn funcs
c.lastServerName = tlsConfig.ServerName
c.lastCA = tlsConfig.RootCAs
}

return &http.Client{
Timeout: 5 * time.Second,
Transport: transport,
Expand Down
170 changes: 159 additions & 11 deletions pkg/libs/endpointaccessible/endpoint_accessible_controller_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,59 +2,207 @@ package endpointaccessible

import (
"context"
"crypto/tls"
"crypto/x509"
"fmt"
"testing"
"time"

operatorv1 "github.com/openshift/api/operator/v1"
"github.com/openshift/library-go/pkg/operator/v1helpers"
"k8s.io/apimachinery/pkg/util/sets"

"github.com/openshift/library-go/pkg/controller/factory"
"github.com/openshift/library-go/pkg/operator/events"
)

func Test_endpointAccessibleController_sync(t *testing.T) {
maxCheckLatency := 55 * time.Second

systemRootCAs, err := x509.SystemCertPool()
if err != nil {
t.Errorf("unexpected error when getting system cert pool: %v", err)
}

getTLSConfigFn := func(serverName string, returnErr error) func() (*tls.Config, error) {
return func() (*tls.Config, error) {
return &tls.Config{
RootCAs: systemRootCAs,
ServerName: serverName,
}, returnErr
}
}

getTLSConfigFnEmptyRootCAs := func(serverName string, returnErr error) func() (*tls.Config, error) {
return func() (*tls.Config, error) {
return &tls.Config{
RootCAs: x509.NewCertPool(),
ServerName: serverName,
}, returnErr
}
}

tests := []struct {
name string
endpointListFn EndpointListFunc
wantErr bool
name string
endpointListFn EndpointListFunc
getTLSConfigFn EndpointTLSConfigFunc
lastCheckTime time.Time
lastEndpoints sets.Set[string]
lastServerName string
lastCA *x509.CertPool
wantCheckExecuted bool
wantErr bool
}{
{
name: "all endpoints working",
name: "all endpoints working",
getTLSConfigFn: getTLSConfigFn("google.com", nil),
endpointListFn: func() ([]string, error) {
return []string{"https://google.com"}, nil
},
wantCheckExecuted: true,
},
{
name: "all endpoints working with tls config",
getTLSConfigFn: getTLSConfigFn("google.com", nil),
endpointListFn: func() ([]string, error) {
return []string{"https://google.com"}, nil
},
wantCheckExecuted: true,
},
{
name: "check working when endpoints change",
getTLSConfigFn: getTLSConfigFn("google.com", nil),
endpointListFn: func() ([]string, error) {
return []string{"https://google.com"}, nil
},
lastEndpoints: sets.New[string]("https://www.google.com"),
lastCheckTime: time.Now().Add(-1 * time.Second),
lastServerName: "google.com",
lastCA: systemRootCAs,
wantCheckExecuted: true,
},
{
name: "check working when check is due",
getTLSConfigFn: getTLSConfigFn("google.com", nil),
endpointListFn: func() ([]string, error) {
return []string{"https://google.com"}, nil
},
lastEndpoints: sets.New[string]("https://google.com"),
lastCheckTime: time.Now().Add(-2 * maxCheckLatency),
lastServerName: "google.com",
lastCA: systemRootCAs,
wantCheckExecuted: true,
},
{
name: "check working when tls server name changes",
getTLSConfigFn: getTLSConfigFn("google.com", nil),
endpointListFn: func() ([]string, error) {
return []string{"https://google.com"}, nil
},
lastEndpoints: sets.New[string]("https://google.com"),
lastCheckTime: time.Now().Add(-1 * time.Second),
lastServerName: "redhat.com",
lastCA: systemRootCAs,
wantCheckExecuted: true,
},
{
name: "check working when tls root CAs change",
getTLSConfigFn: getTLSConfigFn("google.com", nil),
endpointListFn: func() ([]string, error) {
return []string{"https://google.com"}, nil
},
lastEndpoints: sets.New[string]("https://google.com"),
lastCheckTime: time.Now().Add(-1 * time.Second),
lastServerName: "google.com",
lastCA: x509.NewCertPool(),
wantCheckExecuted: true,
},
{
name: "check skipped when no changes in parameters and check is not due",
getTLSConfigFn: getTLSConfigFn("google.com", nil),
endpointListFn: func() ([]string, error) {
return []string{"https://google.com"}, nil
},
lastEndpoints: sets.New[string]("https://google.com"),
lastCheckTime: time.Now().Add(-1 * time.Second),
lastServerName: "google.com",
lastCA: systemRootCAs,
wantCheckExecuted: false,
wantErr: false,
},
{
name: "check fails when tls config fails",
getTLSConfigFn: getTLSConfigFn("google.com", fmt.Errorf("tls config error")),
endpointListFn: func() ([]string, error) {
return []string{"https://google.com"}, nil
},
wantCheckExecuted: false,
wantErr: true,
},
{
name: "check fails when tls server name invalid",
getTLSConfigFn: getTLSConfigFn("g00gle.com", nil),
endpointListFn: func() ([]string, error) {
return []string{"https://google.com"}, nil
},
wantCheckExecuted: true,
wantErr: true,
},
{
name: "check fails when tls rootCAs invalid",
getTLSConfigFn: getTLSConfigFnEmptyRootCAs("google.com", nil),
endpointListFn: func() ([]string, error) {
return []string{"https://google.com"}, nil
},
wantCheckExecuted: true,
wantErr: true,
},
{
name: "endpoints lister error",
endpointListFn: func() ([]string, error) {
return nil, fmt.Errorf("some error")
},
wantErr: true,
wantCheckExecuted: false,
wantErr: true,
},
{
name: "non working endpoints",
name: "non working endpoints",
getTLSConfigFn: getTLSConfigFn("google.com", nil),
endpointListFn: func() ([]string, error) {
return []string{"https://google.com", "https://nonexistenturl.com"}, nil
},
wantErr: true,
wantCheckExecuted: true,
wantErr: true,
},
{
name: "invalid url",
name: "invalid url",
getTLSConfigFn: getTLSConfigFn("google.com", nil),
endpointListFn: func() ([]string, error) {
return []string{"htt//bad`string"}, nil
},
wantErr: true,
wantCheckExecuted: true,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := &endpointAccessibleController{
operatorClient: v1helpers.NewFakeOperatorClient(&operatorv1.OperatorSpec{}, &operatorv1.OperatorStatus{}, nil),
endpointListFn: tt.endpointListFn,
operatorClient: v1helpers.NewFakeOperatorClient(&operatorv1.OperatorSpec{}, &operatorv1.OperatorStatus{}, nil),
getTLSConfigFn: tt.getTLSConfigFn,
endpointListFn: tt.endpointListFn,
maxCheckLatency: maxCheckLatency,
lastEndpoints: tt.lastEndpoints,
lastCheckTime: tt.lastCheckTime,
lastServerName: tt.lastServerName,
lastCA: tt.lastCA,
}
prevLastCheckTime := c.lastCheckTime
if err := c.sync(context.Background(), factory.NewSyncContext(tt.name, events.NewInMemoryRecorder(tt.name))); (err != nil) != tt.wantErr {
t.Errorf("sync() error = %v, wantErr %v", err, tt.wantErr)
}
if tt.wantCheckExecuted != (!prevLastCheckTime.Equal(c.lastCheckTime)) {
t.Errorf("sync() check was executed when it should have been skipped")
}
})
}
}

0 comments on commit 91360bd

Please sign in to comment.