Skip to content

Commit

Permalink
Handle context in plugins (#361)
Browse files Browse the repository at this point in the history
  • Loading branch information
balanza authored Jan 23, 2025
1 parent 732771b commit 83b891f
Show file tree
Hide file tree
Showing 7 changed files with 270 additions and 18 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ agent:

.PHONY: build-plugin-examples
build-plugin-examples:
$(GO_BUILD) -o $(BUILD_DIR)/$(CURRENT_ARCH)/plugin_examples/dummy ./plugin_examples/dummy.go
$(GO_BUILD) -o $(BUILD_DIR)/$(CURRENT_ARCH)/plugin_examples/dummy ./plugin_examples/dummy/dummy.go
$(GO_BUILD) -o $(BUILD_DIR)/$(CURRENT_ARCH)/plugin_examples/sleep ./plugin_examples/sleep/sleep.go

.PHONY: cross-compiled $(ARCHS)
cross-compiled: $(ARCHS)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package gatherers

import (
"context"
"os"
"os/exec"

"github.com/hashicorp/go-hclog"
"github.com/pkg/errors"

goplugin "github.com/hashicorp/go-plugin"
Expand Down Expand Up @@ -33,7 +33,8 @@ func (l *RPCPluginLoader) Load(pluginPath string) (FactGatherer, error) {
AllowedProtocols: []goplugin.Protocol{
goplugin.ProtocolNetRPC,
},
Logger: hclog.Default(),
SyncStdout: os.Stdout,
SyncStderr: os.Stderr,
})

rpcClient, err := client.Client()
Expand All @@ -47,22 +48,22 @@ func (l *RPCPluginLoader) Load(pluginPath string) (FactGatherer, error) {
return nil, errors.Wrap(err, "Error dispensing plugin")
}

g, ok := raw.(plugininterface.Gatherer)
pluginClient, ok := raw.(plugininterface.GathererRPC)
if !ok {
return nil, errors.Wrap(err, "Error asserting Gatherer type")
}

p := &PluggedGatherer{
plugin: g,
pluginClient: pluginClient,
}

return p, nil
}

type PluggedGatherer struct {
plugin plugininterface.Gatherer
pluginClient plugininterface.GathererRPC
}

func (g *PluggedGatherer) Gather(_ context.Context, factsRequests []entities.FactRequest) ([]entities.Fact, error) {
return g.plugin.Gather(factsRequests)
func (g *PluggedGatherer) Gather(ctx context.Context, factsRequests []entities.FactRequest) ([]entities.Fact, error) {
return g.pluginClient.RequestGathering(ctx, factsRequests)
}
5 changes: 3 additions & 2 deletions pkg/factsengine/plugininterface/interface.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package plugininterface

import (
"context"
"encoding/gob"
"net/rpc"

Expand All @@ -21,7 +22,7 @@ func init() {

// Gatherer is the interface exposed as a plugin.
type Gatherer interface {
Gather(factsRequests []entities.FactRequest) ([]entities.Fact, error)
Gather(context context.Context, factsRequests []entities.FactRequest) ([]entities.Fact, error)
}

// This is the implementation of plugin.Plugin
Expand All @@ -35,5 +36,5 @@ func (p *GathererPlugin) Server(*plugin.MuxBroker) (interface{}, error) {
}

func (GathererPlugin) Client(_ *plugin.MuxBroker, c *rpc.Client) (interface{}, error) {
return &GathererRPC{client: c}, nil
return GathererRPC{client: c}, nil
}
65 changes: 59 additions & 6 deletions pkg/factsengine/plugininterface/rpc.go
Original file line number Diff line number Diff line change
@@ -1,27 +1,80 @@
package plugininterface

import (
"context"
"net/rpc"

log "github.com/sirupsen/logrus"

"github.com/google/uuid"
"github.com/trento-project/agent/pkg/factsengine/entities"
)

type GathererRPC struct{ client *rpc.Client }

func (g *GathererRPC) Gather(factsRequest []entities.FactRequest) ([]entities.Fact, error) {
func (g *GathererRPC) RequestGathering(
ctx context.Context,
factsRequest []entities.FactRequest,
) ([]entities.Fact, error) {
var resp []entities.Fact
var err error

requestID := uuid.New().String()
args := GatheringArgs{
FactRequests: factsRequest,
RequestID: requestID,
}

err := g.client.Call("Plugin.Gather", factsRequest, &resp)
gathering := make(chan error)

return resp, err
go func() {
gathering <- g.client.Call("Plugin.ServeGathering", args, &resp)
}()

select {
case <-ctx.Done():
err = g.client.Call("Plugin.Cancel", requestID, &resp)
return []entities.Fact{}, err
case err = <-gathering:
if err != nil {
return nil, err
}
return resp, nil
}
}

type GathererRPCServer struct {
Impl Gatherer
Impl Gatherer
cancelMap map[string]context.CancelFunc
}

type GatheringArgs struct {
FactRequests []entities.FactRequest
RequestID string
}

func (s *GathererRPCServer) Gather(args []entities.FactRequest, resp *[]entities.Fact) error {
func (s *GathererRPCServer) ServeGathering(args GatheringArgs, resp *[]entities.Fact) error {

ctx, cancel := context.WithCancel(context.Background())
if s.cancelMap == nil {
s.cancelMap = make(map[string]context.CancelFunc)
}
s.cancelMap[args.RequestID] = cancel
defer delete(s.cancelMap, args.RequestID)

var err error
*resp, err = s.Impl.Gather(args)
*resp, err = s.Impl.Gather(ctx, args.FactRequests)
return err
}

func (s *GathererRPCServer) Cancel(requestID string, _ *[]entities.Fact) (_ error) {
cancel, ok := s.cancelMap[requestID]
if ok {
cancel()
delete(s.cancelMap, requestID)
} else {
log.Warnf("Cannot find cancel function for request %s", requestID)
}

return nil
}
5 changes: 3 additions & 2 deletions plugin_examples/dummy.go → plugin_examples/dummy/dummy.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package main

// go build -o /usr/etc/trento/dummy ./plugin_examples/dummy.go
// go build -o /usr/etc/trento/dummy ./plugin_examples/dummy/dummy.go

import (
"context"
"fmt"
"math/rand"

Expand All @@ -15,7 +16,7 @@ import (
type dummyGatherer struct {
}

func (s dummyGatherer) Gather(factsRequests []entities.FactRequest) ([]entities.Fact, error) {
func (s dummyGatherer) Gather(_ context.Context, factsRequests []entities.FactRequest) ([]entities.Fact, error) {
facts := []entities.Fact{}
log.Infof("Starting dummy plugin facts gathering process")

Expand Down
68 changes: 68 additions & 0 deletions plugin_examples/sleep/sleep.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package main

// go build -o /usr/etc/trento/sleep ./plugin_examples/sleep/sleep.go

import (
"context"
"fmt"
"os/exec"
"sync"

"github.com/hashicorp/go-plugin"
log "github.com/sirupsen/logrus"
"github.com/trento-project/agent/pkg/factsengine/entities"
"github.com/trento-project/agent/pkg/factsengine/plugininterface"
)

type sleepGatherer struct {
}

func (s sleepGatherer) Gather(ctx context.Context, factsRequests []entities.FactRequest) ([]entities.Fact, error) {
facts := []entities.Fact{}

log.Infof("Starting sleep plugin facts gathering process")

wg := sync.WaitGroup{}

for _, factReq := range factsRequests {
log.Infof("Sleeping for %s", factReq.Argument)
fact := entities.NewFactGatheredWithRequest(factReq, &entities.FactValueString{Value: fmt.Sprint(factReq.Argument)})
facts = append(facts, fact)

time := fmt.Sprint(factReq.Argument)
wg.Add(1)
go func(time string) {
defer wg.Done()
cmd := exec.CommandContext(ctx, "sleep", time)
err := cmd.Run()
if err != nil {
log.Errorf("Error running sleep command: %s", err)
}
}(time)

}

wg.Wait()

log.Infof("Requested sleep plugin facts gathered")
return facts, nil
}

func main() {
d := &sleepGatherer{}

handshakeConfig := plugin.HandshakeConfig{
ProtocolVersion: 1,
MagicCookieKey: "TRENTO_PLUGIN",
MagicCookieValue: "gatherer",
}

var pluginMap = map[string]plugin.Plugin{
"gatherer": &plugininterface.GathererPlugin{Impl: d},
}

plugin.Serve(&plugin.ServeConfig{ // nolint
HandshakeConfig: handshakeConfig,
Plugins: pluginMap,
})
}
127 changes: 127 additions & 0 deletions test/cli/plugins.bats
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ setup() {
PATH="$BUILD_DIR:$PATH"
}

teardown() {
# kill all the processes started by the test
pkill -P $$ || true
}

@test "it should include the dummy plugin into list" {
run trento-agent facts list --plugins-folder $BUILD_DIR/plugin_examples
Expand All @@ -32,3 +36,126 @@ setup() {
[ "$status" -eq 0 ]
echo $output | grep -q "Name: 2"
}

@test "it should remove all the processes on complete" {
# declare the expected processes
cmd_agent="trento-agent facts gather --plugins-folder $BUILD_DIR/plugin_examples --gatherer sleep --argument 2s"
cmd_plugin="$BUILD_DIR/plugin_examples/sleep"
cmd_sleep="sleep 2s"

# start the agent in background
eval "$cmd_agent &"
pid=$!

# retrieve the pid of the exepcted process
pid_agent=$(pgrep -f "$cmd_agent")
pid_plugin=$(pgrep -f "$cmd_plugin")
pid_sleep=$(pgrep -f "$cmd_sleep")

# double check the test is correct
[ $pid -eq $pid_agent ]

# ensure no duplicated processes are running
assert_one "$pid_agent"
assert_one "$pid_plugin"
assert_one "$pid_sleep"

# ensure the process tree is correct
assert_parent "$pid_plugin" "$pid_agent"
assert_parent "$pid_sleep" "$pid_plugin"

# wait for the process to finish
while kill -0 $pid 2>/dev/null; do
sleep 1
done

# test processes are killed
assert_no_pid "$pid_agent"
assert_no_pid "$pid_plugin"
assert_no_pid "$pid_sleep"
}

@test "it should remove all the processes on agent process stopped (SIGINT)" {
# declare the expected processes
cmd_agent="trento-agent facts gather --plugins-folder $BUILD_DIR/plugin_examples --gatherer sleep --argument 2s"
cmd_plugin="$BUILD_DIR/plugin_examples/sleep"
cmd_sleep="sleep 2s"

# start the agent in background
eval "$cmd_agent &"
pid=$!

# retrieve the pid of the exepcted process
pid_agent=$(pgrep -f "$cmd_agent")
pid_plugin=$(pgrep -f "$cmd_plugin")
pid_sleep=$(pgrep -f "$cmd_sleep")

# double check the test is correct
[ $pid -eq $pid_agent ]

# ensure no duplicated processes are running
assert_one "$pid_agent"
assert_one "$pid_plugin"
assert_one "$pid_sleep"

# ensure the process tree is correct
assert_parent "$pid_plugin" "$pid_agent"
assert_parent "$pid_sleep" "$pid_plugin"

# kill the agent
kill -INT $pid_agent

# test processes are killed
assert_no_pid "$pid_agent"
assert_no_pid "$pid_plugin"
assert_no_pid "$pid_sleep"
}

@test "it should remove all the processes on agent process stopped (SIGTERM)" {
# declare the expected processes
cmd_agent="trento-agent facts gather --plugins-folder $BUILD_DIR/plugin_examples --gatherer sleep --argument 2s"
cmd_plugin="$BUILD_DIR/plugin_examples/sleep"
cmd_sleep="sleep 2s"

# start the agent in background
eval "$cmd_agent &"
pid=$!

# retrieve the pid of the exepcted process
pid_agent=$(pgrep -f "$cmd_agent")
pid_plugin=$(pgrep -f "$cmd_plugin")
pid_sleep=$(pgrep -f "$cmd_sleep")

# double check the test is correct
[ $pid -eq $pid_agent ]

# ensure no duplicated processes are running
assert_one "$pid_agent"
assert_one "$pid_plugin"
assert_one "$pid_sleep"

# ensure the process tree is correct
assert_parent "$pid_plugin" "$pid_agent"
assert_parent "$pid_sleep" "$pid_plugin"

# kill the agent
kill -TERM $pid_agent

# test processes are killed
assert_no_pid "$pid_agent"
assert_no_pid "$pid_plugin"
assert_no_pid "$pid_sleep"
}


function assert_one {
[ $(echo "$1" | wc -l) == 1 ]
}

function assert_parent {
[ $(ps -o ppid= -p "$1") == "$2" ]
}

function assert_no_pid {
[ $(ps -p "$1" | wc -l) == 1 ]
}

0 comments on commit 83b891f

Please sign in to comment.