From d66b0600f2ec77b0cfd80b2cc634f270db24e3e5 Mon Sep 17 00:00:00 2001 From: TungHoang Date: Tue, 22 Nov 2022 14:31:00 +0100 Subject: [PATCH 1/3] Add authorization API to the client - Add AddPolicy, RevokePolicy, ListPolicy implementation - Add unit tests --- v2/client.go | 66 +++++++++++++++++++++++++++++ v2/client_test.go | 67 +++++++++++++++++++++++++++++ v2/common_test.go | 105 ++++++++++++++++++++++++++++++++++++++++++++++ v2/metadata.go | 6 +++ 4 files changed, 244 insertions(+) diff --git a/v2/client.go b/v2/client.go index 024488b..ea8af10 100644 --- a/v2/client.go +++ b/v2/client.go @@ -578,6 +578,15 @@ type Client interface { // consumes a set of streams. Liftbridge handles assigning partitions to // the group members and tracking the group's position in the streams. CreateConsumer(groupID string, opts ...ConsumerOption) (*Consumer, error) + + // AddPolicy add an ACL policy to the cluster + AddPolicy(ctx context.Context, UserId, ResourceId, Action string) error + + // RevokePolicy revokes an existing ACL policy from the cluster + RevokePolicy(ctx context.Context, UserId, ResourceId, Action string) error + + // ListPolicy retrieves all existing ACL policies from the cluster + ListPolicy(ctx context.Context) (map[int32]*ACLPolicy, error) } // client implements the Client interface. It maintains a pool of connections @@ -1578,6 +1587,63 @@ func (c *client) FetchCursor(ctx context.Context, id, stream string, partition i return offset, err } +// AddPolicy adds an ACL style authorization policy to the cluster. +// The client must either be `root` or have `AddPolicy` permission to execute this operation +func (c *client) AddPolicy(ctx context.Context, userId, resourceId, action string) error { + var ( + req = &proto.AddPolicyRequest{} + ) + + err := c.doResilientRPC(ctx, func(client proto.APIClient) error { + _, err := client.AddPolicy(ctx, req) + if err != nil { + return err + } + return nil + }) + return err +} + +// RevokePolicy removes an existing authorization policy from the cluster. +// The client must either be `root` or have `RevokePolicy` permission to execute this operation +func (c *client) RevokePolicy(ctx context.Context, userId, resourceId, action string) error { + var ( + req = &proto.RevokePolicyRequest{} + ) + + err := c.doResilientRPC(ctx, func(client proto.APIClient) error { + _, err := client.RevokePolicy(ctx, req) + if err != nil { + return err + } + return nil + }) + return err +} + +// ListPolicy lists all existing ACL authorization policies from the cluster. +// The client must either be `root` or have `ListPolicy` permission to execute this operation +func (c *client) ListPolicy(ctx context.Context) (map[int32]*ACLPolicy, error) { + var ( + req = &proto.ListPolicyRequest{} + policies = make(map[int32]*ACLPolicy) + ) + + err := c.doResilientRPC(ctx, func(client proto.APIClient) error { + resp, err := client.ListPolicy(ctx, req) + if err != nil { + return err + } + for i, policy := range resp.Policies { + + policies[int32(i)] = &ACLPolicy{ + UserId: policy.UserId, ResourceId: policy.ResourceId, Action: policy.Action} + } + return nil + }) + return policies, err +} + func (c *client) getCursorKey(cursorID, streamName string, partitionID int32) []byte { return []byte(fmt.Sprintf("%s,%s,%d", cursorID, streamName, partitionID)) } diff --git a/v2/client_test.go b/v2/client_test.go index ecb113a..8a7980c 100644 --- a/v2/client_test.go +++ b/v2/client_test.go @@ -1911,6 +1911,73 @@ func TestConnectToServerBasedOnWorkLoad(t *testing.T) { require.Len(t, metadata.Brokers(), 2) } +func TestAddPolicy(t *testing.T) { + server := newMockServer() + defer server.Stop(t) + port := server.Start(t) + + server.SetupMockFetchMetadataResponse(new(proto.FetchMetadataResponse)) + + server.SetupMockAddPolicyResponse(new(proto.AddPolicyResponse)) + + client, err := Connect([]string{fmt.Sprintf("localhost:%d", port)}, + AckWaitTime(time.Nanosecond)) + require.NoError(t, err) + defer client.Close() + + err = client.AddPolicy(context.Background(), "a", "b", "read") + require.NoError(t, err) +} + +func TestRevokePolicy(t *testing.T) { + server := newMockServer() + defer server.Stop(t) + port := server.Start(t) + + server.SetupMockFetchMetadataResponse(new(proto.FetchMetadataResponse)) + + server.SetupMockRevokePolicyResponse(new(proto.RevokePolicyResponse)) + + client, err := Connect([]string{fmt.Sprintf("localhost:%d", port)}, + AckWaitTime(time.Nanosecond)) + require.NoError(t, err) + defer client.Close() + + err = client.RevokePolicy(context.Background(), "a", "b", "read") + require.NoError(t, err) +} + +func TestListPolicy(t *testing.T) { + server := newMockServer() + defer server.Stop(t) + port := server.Start(t) + + mockPolicies := make(map[int32]*proto.ACLPolicy) + mockPolicies[0] = &proto.ACLPolicy{UserId: "a", ResourceId: "b", Action: "read"} + mockPolicies[1] = &proto.ACLPolicy{UserId: "c", ResourceId: "d", Action: "write"} + + server.SetupMockFetchMetadataResponse(new(proto.FetchMetadataResponse)) + + server.SetupMockListPolicyResponse(&proto.ListPolicyResponse{Policies: mockPolicies}) + + client, err := Connect([]string{fmt.Sprintf("localhost:%d", port)}, + AckWaitTime(time.Nanosecond)) + require.NoError(t, err) + defer client.Close() + + policies, err := client.ListPolicy(context.Background()) + require.NoError(t, err) + + expectedPolicies := make(map[int32]*ACLPolicy) + for i, policy := range mockPolicies { + + expectedPolicies[int32(i)] = &ACLPolicy{ + UserId: policy.UserId, ResourceId: policy.ResourceId, Action: policy.Action} + } + + require.Equal(t, expectedPolicies, policies) +} + func ExampleConnect() { addr := "localhost:9292" client, err := Connect([]string{addr}) diff --git a/v2/common_test.go b/v2/common_test.go index ea109a8..06795f6 100644 --- a/v2/common_test.go +++ b/v2/common_test.go @@ -123,6 +123,9 @@ type mockAPI struct { joinConsumerGroupRequests []*proto.JoinConsumerGroupRequest leaveConsumerGroupRequests []*proto.LeaveConsumerGroupRequest fetchConsumerGroupAssignmentsRequests []*proto.FetchConsumerGroupAssignmentsRequest + addPolicyRequests []*proto.AddPolicyRequest + revokePolicyRequests []*proto.RevokePolicyRequest + listPolicyRequests []*proto.ListPolicyRequest responses map[string]interface{} messages []*proto.Message createStreamErr error @@ -141,6 +144,9 @@ type mockAPI struct { leaveConsumerGroupErr error fetchConsumerGroupAssignmentsErr error fetchPartitionMetadataErr error + addPolicyErr error + revokePolicyError error + listPolicyError error // autclearError indicates where the mock API shall clear mock error automatically autoClearError bool // delayMetadataResponse indicates the FetchMetadata call shall be delayed for few seconds @@ -164,6 +170,9 @@ func newMockAPI() *mockAPI { joinConsumerGroupRequests: []*proto.JoinConsumerGroupRequest{}, leaveConsumerGroupRequests: []*proto.LeaveConsumerGroupRequest{}, fetchConsumerGroupAssignmentsRequests: []*proto.FetchConsumerGroupAssignmentsRequest{}, + addPolicyRequests: []*proto.AddPolicyRequest{}, + revokePolicyRequests: []*proto.RevokePolicyRequest{}, + listPolicyRequests: []*proto.ListPolicyRequest{}, responses: make(map[string]interface{}), autoClearError: false, } @@ -242,6 +251,18 @@ func (m *mockAPI) SetupMockReportConsumerGroupCoordinatorResponse(responses inte m.responses["ReportConsumerGroupCoordinator"] = responses } +func (m *mockAPI) SetupMockAddPolicyResponse(responses interface{}) { + m.responses["AddPolicy"] = responses +} + +func (m *mockAPI) SetupMockRevokePolicyResponse(responses interface{}) { + m.responses["RevokePolicy"] = responses +} + +func (m *mockAPI) SetupMockListPolicyResponse(responses interface{}) { + m.responses["ListPolicy"] = responses +} + func (m *mockAPI) SetupMockCreateStreamError(err error) { m.mu.Lock() defer m.mu.Unlock() @@ -325,6 +346,24 @@ func (m *mockAPI) SetupMockFetchCursorError(err error) { m.fetchCursorErr = err } +func (m *mockAPI) SetupMockAddPolicyError(err error) { + m.mu.Lock() + defer m.mu.Unlock() + m.addPolicyErr = err +} + +func (m *mockAPI) SetupMockRevokePolicyError(err error) { + m.mu.Lock() + defer m.mu.Unlock() + m.revokePolicyError = err +} + +func (m *mockAPI) SetupMockListPolicyError(err error) { + m.mu.Lock() + defer m.mu.Unlock() + m.listPolicyError = err +} + func (m *mockAPI) GetCreateStreamRequests() []*proto.CreateStreamRequest { m.mu.Lock() defer m.mu.Unlock() @@ -397,6 +436,24 @@ func (m *mockAPI) GetLeaveConsumerGroupRequests() []*proto.LeaveConsumerGroupReq return m.leaveConsumerGroupRequests } +func (m *mockAPI) GetAddPolicyRequests() []*proto.AddPolicyRequest { + m.mu.Lock() + defer m.mu.Unlock() + return m.addPolicyRequests +} + +func (m *mockAPI) GetRevokePolicyRequests() []*proto.RevokePolicyRequest { + m.mu.Lock() + defer m.mu.Unlock() + return m.revokePolicyRequests +} + +func (m *mockAPI) GetLisPolicyRequests() []*proto.ListPolicyRequest { + m.mu.Lock() + defer m.mu.Unlock() + return m.listPolicyRequests +} + func (m *mockAPI) CreateStream(ctx context.Context, in *proto.CreateStreamRequest) (*proto.CreateStreamResponse, error) { m.mu.Lock() defer m.mu.Unlock() @@ -666,3 +723,51 @@ func (m *mockAPI) ReportConsumerGroupCoordinator(ctx context.Context, return nil, errors.New("todo") } + +func (m *mockAPI) AddPolicy(ctx context.Context, in *proto.AddPolicyRequest) (*proto.AddPolicyResponse, error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.addPolicyRequests = append(m.addPolicyRequests, in) + if m.addPolicyErr != nil { + err := m.addPolicyErr + if m.autoClearError { + m.addPolicyErr = nil + } + return nil, err + } + resp := m.responses["AddPolicy"] + return resp.(*proto.AddPolicyResponse), nil +} + +func (m *mockAPI) RevokePolicy(ctx context.Context, in *proto.RevokePolicyRequest) (*proto.RevokePolicyResponse, error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.revokePolicyRequests = append(m.revokePolicyRequests, in) + if m.revokePolicyError != nil { + err := m.revokePolicyError + if m.autoClearError { + m.revokePolicyError = nil + } + return nil, err + } + resp := m.responses["RevokePolicy"] + return resp.(*proto.RevokePolicyResponse), nil +} + +func (m *mockAPI) ListPolicy(ctx context.Context, in *proto.ListPolicyRequest) (*proto.ListPolicyResponse, error) { + m.mu.Lock() + defer m.mu.Unlock() + + m.listPolicyRequests = append(m.listPolicyRequests, in) + if m.listPolicyError != nil { + err := m.listPolicyError + if m.autoClearError { + m.listPolicyError = nil + } + return nil, err + } + resp := m.responses["ListPolicy"] + return resp.(*proto.ListPolicyResponse), nil +} diff --git a/v2/metadata.go b/v2/metadata.go index b86b92c..d4b3a3d 100644 --- a/v2/metadata.go +++ b/v2/metadata.go @@ -442,3 +442,9 @@ func (m *metadataCache) get() *Metadata { defer m.mu.RUnlock() return m.metadata } + +type ACLPolicy struct { + UserId string + ResourceId string + Action string +} From 082c98298d3698f3b56872342ed23f55e18cd6c7 Mon Sep 17 00:00:00 2001 From: TungHoang Date: Wed, 23 Nov 2022 16:14:44 +0100 Subject: [PATCH 2/3] Fix typo in AddPolicy docstring --- v2/client.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/v2/client.go b/v2/client.go index ea8af10..c43f2dd 100644 --- a/v2/client.go +++ b/v2/client.go @@ -579,7 +579,7 @@ type Client interface { // the group members and tracking the group's position in the streams. CreateConsumer(groupID string, opts ...ConsumerOption) (*Consumer, error) - // AddPolicy add an ACL policy to the cluster + // AddPolicy adds an ACL policy to the cluster AddPolicy(ctx context.Context, UserId, ResourceId, Action string) error // RevokePolicy revokes an existing ACL policy from the cluster @@ -1634,11 +1634,12 @@ func (c *client) ListPolicy(ctx context.Context) (map[int32]*ACLPolicy, error) { if err != nil { return err } - for i, policy := range resp.Policies { + for i, policy := range resp.Policies { policies[int32(i)] = &ACLPolicy{ UserId: policy.UserId, ResourceId: policy.ResourceId, Action: policy.Action} } + return nil }) return policies, err From e6aa128cd637713a06c4e8907b8efba659ead1b0 Mon Sep 17 00:00:00 2001 From: TungHoang Date: Sat, 26 Nov 2022 16:20:21 +0100 Subject: [PATCH 3/3] Fix AddPolicyRequest and RevokePolicyRequest - Add User, Resource, Action in policy request --- v2/client.go | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/v2/client.go b/v2/client.go index c43f2dd..6c27d98 100644 --- a/v2/client.go +++ b/v2/client.go @@ -1591,7 +1591,10 @@ func (c *client) FetchCursor(ctx context.Context, id, stream string, partition i // The client must either be `root` or have `AddPolicy` permission to execute this operation func (c *client) AddPolicy(ctx context.Context, userId, resourceId, action string) error { var ( - req = &proto.AddPolicyRequest{} + req = &proto.AddPolicyRequest{Policy: &proto.ACLPolicy{ + UserId: userId, + ResourceId: resourceId, + Action: action}} ) err := c.doResilientRPC(ctx, func(client proto.APIClient) error { @@ -1608,7 +1611,10 @@ func (c *client) AddPolicy(ctx context.Context, userId, resourceId, action strin // The client must either be `root` or have `RevokePolicy` permission to execute this operation func (c *client) RevokePolicy(ctx context.Context, userId, resourceId, action string) error { var ( - req = &proto.RevokePolicyRequest{} + req = &proto.RevokePolicyRequest{Policy: &proto.ACLPolicy{ + UserId: userId, + ResourceId: resourceId, + Action: action}} ) err := c.doResilientRPC(ctx, func(client proto.APIClient) error {