Skip to content

Commit

Permalink
Cross off some minor todos before full testing
Browse files Browse the repository at this point in the history
  • Loading branch information
NHAS committed Jan 21, 2024
1 parent b04c087 commit 93c7e1c
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 25 deletions.
3 changes: 0 additions & 3 deletions internal/data/events.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,6 @@ func RegisterClusterHealthWatcher(fnc ClusterHealthFunc) {
func watchEvents() {
wc := etcd.Watch(context.Background(), "", clientv3.WithPrefix(), clientv3.WithPrevKV())
for watchEvent := range wc {
log.Println("got event: ", watchEvent)

// TODO make sure that we account for compaction events
for _, event := range watchEvent.Events {

var (
Expand Down
8 changes: 4 additions & 4 deletions internal/router/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (

var lock sync.RWMutex

func Setup(error chan<- error, iptables bool) (err error) {
func Setup(errorChan chan<- error, iptables bool) (err error) {

initialUsers, knownDevices, err := data.GetInitialData()
if err != nil {
Expand Down Expand Up @@ -47,14 +47,14 @@ func Setup(error chan<- error, iptables bool) (err error) {
return err
}

handleEvents()
handleEvents(errorChan)

go func() {
startup := true
cache := map[string]string{}
d, err := data.GetAllDevices()
if err != nil {
error <- err
errorChan <- err
return
}

Expand All @@ -66,7 +66,7 @@ func Setup(error chan<- error, iptables bool) (err error) {

dev, err := ctrl.Device(config.Values().Wireguard.DevName)
if err != nil {
error <- fmt.Errorf("endpoint watcher: %s", err)
errorChan <- fmt.Errorf("endpoint watcher: %s", err)
return
}

Expand Down
43 changes: 27 additions & 16 deletions internal/router/statemachine.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ import (
"log"

"github.com/NHAS/wag/internal/acls"
"github.com/NHAS/wag/internal/config"
"github.com/NHAS/wag/internal/data"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
)

func handleEvents() {
func handleEvents(erroChan chan<- error) {
data.RegisterAclsWatcher(aclsChanges)
data.RegisterClusterHealthWatcher(clusterState)
data.RegisterClusterHealthWatcher(clusterState(erroChan))
data.RegisterDeviceWatcher(deviceChanges)
data.RegisterGroupsWatcher(groupChanges)
data.RegisterUserWatcher(userChanges)
Expand Down Expand Up @@ -41,14 +42,23 @@ func deviceChanges(device data.BasicEvent[data.Device], state int) {
}
}

if (device.CurrentValue.Attempts != device.Previous.Attempts && device.CurrentValue.Attempts > 5) ||
if (device.CurrentValue.Attempts != device.Previous.Attempts && device.CurrentValue.Attempts > config.Values().Lockout) ||
device.CurrentValue.Endpoint.String() != device.Previous.Endpoint.String() {
err := Deauthenticate(device.CurrentValue.Address)
if err != nil {
log.Println(err)
}
}

if device.CurrentValue.Authorised != device.Previous.Authorised {
if device.CurrentValue.Attempts <= config.Values().Lockout {
err := SetAuthorized(device.CurrentValue.Address, device.CurrentValue.Username)
if err != nil {
log.Println(err)
}
}
}

default:
panic("unknown state")
}
Expand Down Expand Up @@ -117,19 +127,20 @@ func groupChanges(groupChange data.TargettedEvent[[]string], state int) {
}
}

func clusterState(stateText string, state int) {
switch stateText {
case "dead":
TearDown()
case "healthy":
errors := make(chan error)
go func() {
<-errors
// TODO fix this
}()
err := Setup(errors, true)
if err != nil {
log.Fatal(err)
func clusterState(errorsChan chan<- error) data.ClusterHealthFunc {

return func(stateText string, state int) {
switch stateText {
case "dead":
log.Println("Cluster has entered dead state, tearing down")
TearDown()
case "healthy":
err := Setup(errorsChan, true)
if err != nil {
errorsChan <- err
log.Println("was unable to return wag member to healthy state, dying: ", err)
return
}
}
}
}
2 changes: 0 additions & 2 deletions internal/users/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,6 @@ func (u *user) Authenticate(device, mfaType string, authenticator authenticators
return fmt.Errorf("%s %s unable to reset number of mfa attempts: %s", u.Username, device, err)
}

// TODO gonna have to do an additional something here in order to send the statemachine a message we need to update the ebpf

return nil
}

Expand Down

0 comments on commit 93c7e1c

Please sign in to comment.