diff --git a/internal/data/events.go b/internal/data/events.go index 4f447b59..080c6dd7 100644 --- a/internal/data/events.go +++ b/internal/data/events.go @@ -94,9 +94,6 @@ func RegisterClusterHealthWatcher(fnc ClusterHealthFunc) { func watchEvents() { wc := etcd.Watch(context.Background(), "", clientv3.WithPrefix(), clientv3.WithPrevKV()) for watchEvent := range wc { - log.Println("got event: ", watchEvent) - - // TODO make sure that we account for compaction events for _, event := range watchEvent.Events { var ( diff --git a/internal/router/init.go b/internal/router/init.go index 39b712b0..31398447 100644 --- a/internal/router/init.go +++ b/internal/router/init.go @@ -17,7 +17,7 @@ import ( var lock sync.RWMutex -func Setup(error chan<- error, iptables bool) (err error) { +func Setup(errorChan chan<- error, iptables bool) (err error) { initialUsers, knownDevices, err := data.GetInitialData() if err != nil { @@ -47,14 +47,14 @@ func Setup(error chan<- error, iptables bool) (err error) { return err } - handleEvents() + handleEvents(errorChan) go func() { startup := true cache := map[string]string{} d, err := data.GetAllDevices() if err != nil { - error <- err + errorChan <- err return } @@ -66,7 +66,7 @@ func Setup(error chan<- error, iptables bool) (err error) { dev, err := ctrl.Device(config.Values().Wireguard.DevName) if err != nil { - error <- fmt.Errorf("endpoint watcher: %s", err) + errorChan <- fmt.Errorf("endpoint watcher: %s", err) return } diff --git a/internal/router/statemachine.go b/internal/router/statemachine.go index 6e299608..e33ab72a 100644 --- a/internal/router/statemachine.go +++ b/internal/router/statemachine.go @@ -4,13 +4,14 @@ import ( "log" "github.com/NHAS/wag/internal/acls" + "github.com/NHAS/wag/internal/config" "github.com/NHAS/wag/internal/data" "golang.zx2c4.com/wireguard/wgctrl/wgtypes" ) -func handleEvents() { +func handleEvents(erroChan chan<- error) { data.RegisterAclsWatcher(aclsChanges) - data.RegisterClusterHealthWatcher(clusterState) + data.RegisterClusterHealthWatcher(clusterState(erroChan)) data.RegisterDeviceWatcher(deviceChanges) data.RegisterGroupsWatcher(groupChanges) data.RegisterUserWatcher(userChanges) @@ -41,7 +42,7 @@ func deviceChanges(device data.BasicEvent[data.Device], state int) { } } - if (device.CurrentValue.Attempts != device.Previous.Attempts && device.CurrentValue.Attempts > 5) || + if (device.CurrentValue.Attempts != device.Previous.Attempts && device.CurrentValue.Attempts > config.Values().Lockout) || device.CurrentValue.Endpoint.String() != device.Previous.Endpoint.String() { err := Deauthenticate(device.CurrentValue.Address) if err != nil { @@ -49,6 +50,15 @@ func deviceChanges(device data.BasicEvent[data.Device], state int) { } } + if device.CurrentValue.Authorised != device.Previous.Authorised { + if device.CurrentValue.Attempts <= config.Values().Lockout { + err := SetAuthorized(device.CurrentValue.Address, device.CurrentValue.Username) + if err != nil { + log.Println(err) + } + } + } + default: panic("unknown state") } @@ -117,19 +127,20 @@ func groupChanges(groupChange data.TargettedEvent[[]string], state int) { } } -func clusterState(stateText string, state int) { - switch stateText { - case "dead": - TearDown() - case "healthy": - errors := make(chan error) - go func() { - <-errors - // TODO fix this - }() - err := Setup(errors, true) - if err != nil { - log.Fatal(err) +func clusterState(errorsChan chan<- error) data.ClusterHealthFunc { + + return func(stateText string, state int) { + switch stateText { + case "dead": + log.Println("Cluster has entered dead state, tearing down") + TearDown() + case "healthy": + err := Setup(errorsChan, true) + if err != nil { + errorsChan <- err + log.Println("was unable to return wag member to healthy state, dying: ", err) + return + } } } } diff --git a/internal/users/user.go b/internal/users/user.go index ebeed69f..8d5fc641 100644 --- a/internal/users/user.go +++ b/internal/users/user.go @@ -150,8 +150,6 @@ func (u *user) Authenticate(device, mfaType string, authenticator authenticators return fmt.Errorf("%s %s unable to reset number of mfa attempts: %s", u.Username, device, err) } - // TODO gonna have to do an additional something here in order to send the statemachine a message we need to update the ebpf - return nil }