Skip to content

Commit

Permalink
feedback
Browse files Browse the repository at this point in the history
  • Loading branch information
alekseyvdovenko committed Nov 21, 2024
1 parent 0df415d commit 9998347
Show file tree
Hide file tree
Showing 7 changed files with 65 additions and 61 deletions.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

@Data
@EqualsAndHashCode(callSuper = true)
public abstract class Deployment extends AccessControlled {
public abstract class Deployment extends RoleBasedEntity {
private String endpoint;
@JsonAlias({"displayName", "display_name"})
private String displayName;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package com.epam.aidial.core.config;

import com.fasterxml.jackson.annotation.JsonAlias;
import lombok.Data;

import java.util.List;
import java.util.Set;

@Data
public abstract class RoleBasedEntity {

private String name;

@JsonAlias({"userRoles", "user_roles"})
private Set<String> userRoles;

/**
* Checks if the actual user roles ({@code actualUserRoles} parameter) contain any of the expected user roles ({@code userRoles} field).
* The method verifies if any of the passed role is allowed to operate with the deployment, route or any other instance that extends {@link RoleBasedEntity}.
*
* @return true if one of the {@code actualUserRoles} exists in the {@code userRoles} or if {@code userRoles} field is empty,
* meaning that the current descendant of {@link RoleBasedEntity} allows access to it for any role.
*/
public boolean hasAccess(List<String> actualUserRoles) {
Set<String> expectedUserRoles = getUserRoles();

if (expectedUserRoles == null) {
return true;
}

return !expectedUserRoles.isEmpty()
&& actualUserRoles.stream().anyMatch(expectedUserRoles::contains);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

@Data
@EqualsAndHashCode(callSuper = true)
public class Route extends AccessControlled {
public class Route extends RoleBasedEntity {

private Response response;
private boolean rewritePath;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,8 @@ private void handleRequestBody(Buffer requestBody) {
handleRateLimitHit(rateLimitResult);
return Future.succeededFuture();
}
});
})
.onFailure(this::handleError);
} else {
context.getResponse().send(context.getResponseBody());
proxy.getLogStore().save(context);
Expand Down Expand Up @@ -203,6 +204,12 @@ private void handleRateLimitHit(RateLimitResult result) {
context.respond(result.status(), rateLimitError);
}

private void handleError(Throwable error) {
String route = context.getRoute().getName();
log.error("Failed to handle route {}", route, error);
context.respond(HttpStatus.INTERNAL_SERVER_ERROR, "Failed to process route request: " + route);
}

/**
* Called when proxy failed to receive request body from the client.
*/
Expand Down Expand Up @@ -274,7 +281,7 @@ private Route selectRoute() {
@SneakyThrows
private String getEndpointUri(Upstream upstream) {
URIBuilder uriBuilder = new URIBuilder(upstream.getEndpoint());
if (context.getRoute() != null && context.getRoute().isRewritePath()) {
if (context.getRoute().isRewritePath()) {
uriBuilder.setPath(context.getRequest().path());
}
return uriBuilder.toString();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package com.epam.aidial.core.server.limiter;

import com.epam.aidial.core.config.AccessControlled;
import com.epam.aidial.core.config.Limit;
import com.epam.aidial.core.config.Role;
import com.epam.aidial.core.config.RoleBasedEntity;
import com.epam.aidial.core.server.ProxyContext;
import com.epam.aidial.core.server.data.ItemLimitStats;
import com.epam.aidial.core.server.data.LimitStats;
Expand Down Expand Up @@ -55,14 +55,14 @@ public Future<Void> increase(ProxyContext context) {
}
}

public Future<RateLimitResult> limit(ProxyContext context, AccessControlled accessControlled) {
public Future<RateLimitResult> limit(ProxyContext context, RoleBasedEntity roleBasedEntity) {
try {
// skip checking limits if redis is not available
if (resourceService == null) {
return Future.succeededFuture(RateLimitResult.SUCCESS);
}
String name = accessControlled.getName();
Limit limit = getLimitByUser(context, accessControlled);
String name = roleBasedEntity.getName();
Limit limit = getLimitByUser(context, roleBasedEntity);

if (limit == null || !limit.isPositive()) {
if (limit == null) {
Expand All @@ -73,20 +73,20 @@ public Future<RateLimitResult> limit(ProxyContext context, AccessControlled acce
return Future.succeededFuture(new RateLimitResult(HttpStatus.FORBIDDEN, "Access denied"));
}

return vertx.executeBlocking(() -> checkLimit(context, limit, accessControlled), false);
return vertx.executeBlocking(() -> checkLimit(context, limit, roleBasedEntity), false);
} catch (Throwable e) {
return Future.failedFuture(e);
}
}

public Future<LimitStats> getLimitStats(AccessControlled accessControlled, ProxyContext context) {
public Future<LimitStats> getLimitStats(RoleBasedEntity roleBasedEntity, ProxyContext context) {
try {
// skip checking limits if redis is not available
if (resourceService == null) {
return Future.succeededFuture();
}
Limit limit = getLimitByUser(context, accessControlled);
return vertx.executeBlocking(() -> getLimitStats(context, limit, accessControlled.getName()), false);
Limit limit = getLimitByUser(context, roleBasedEntity);
return vertx.executeBlocking(() -> getLimitStats(context, limit, roleBasedEntity.getName()), false);
} catch (Throwable e) {
return Future.failedFuture(e);
}
Expand Down Expand Up @@ -152,17 +152,17 @@ private ResourceDescriptor getResourceDescription(ProxyContext context, String p
return ResourceDescriptorFactory.fromEncoded(ResourceTypes.LIMIT, bucketLocation, bucketLocation, path);
}

private RateLimitResult checkLimit(ProxyContext context, Limit limit, AccessControlled accessControlled) {
private RateLimitResult checkLimit(ProxyContext context, Limit limit, RoleBasedEntity roleBasedEntity) {
long timestamp = System.currentTimeMillis();
RateLimitResult tokenResult = checkTokenLimit(context, limit, timestamp, accessControlled);
RateLimitResult tokenResult = checkTokenLimit(context, limit, timestamp, roleBasedEntity);
if (tokenResult.status() != HttpStatus.OK) {
return tokenResult;
}
return checkRequestLimit(context, limit, timestamp, accessControlled);
return checkRequestLimit(context, limit, timestamp, roleBasedEntity);
}

private RateLimitResult checkTokenLimit(ProxyContext context, Limit limit, long timestamp, AccessControlled accessControlled) {
String tokensPath = getPathToTokens(accessControlled.getName());
private RateLimitResult checkTokenLimit(ProxyContext context, Limit limit, long timestamp, RoleBasedEntity roleBasedEntity) {
String tokensPath = getPathToTokens(roleBasedEntity.getName());
ResourceDescriptor resourceDescription = getResourceDescription(context, tokensPath);
String prevValue = resourceService.getResource(resourceDescription);
TokenRateLimit rateLimit = ProxyUtil.convertToObject(prevValue, TokenRateLimit.class);
Expand All @@ -172,8 +172,8 @@ private RateLimitResult checkTokenLimit(ProxyContext context, Limit limit, long
return rateLimit.update(timestamp, limit);
}

private RateLimitResult checkRequestLimit(ProxyContext context, Limit limit, long timestamp, AccessControlled accessControlled) {
String tokensPath = getPathToRequests(accessControlled.getName());
private RateLimitResult checkRequestLimit(ProxyContext context, Limit limit, long timestamp, RoleBasedEntity roleBasedEntity) {
String tokensPath = getPathToRequests(roleBasedEntity.getName());
ResourceDescriptor resourceDescription = getResourceDescription(context, tokensPath);
// pass array to hold rate limit result returned by the function to compute the resource
RateLimitResult[] result = new RateLimitResult[1];
Expand Down Expand Up @@ -205,15 +205,15 @@ private String updateTokenLimit(String json, long totalUsedTokens) {
return ProxyUtil.convertToString(rateLimit);
}

private Limit getLimitByUser(ProxyContext context, AccessControlled accessControlled) {
String name = accessControlled.getName();
private Limit getLimitByUser(ProxyContext context, RoleBasedEntity roleBasedEntity) {
String name = roleBasedEntity.getName();
List<String> userRoles;
if (accessControlled.getUserRoles() == null) {
if (roleBasedEntity.getUserRoles() == null) {
// find limits for all user roles
userRoles = context.getUserRoles();
} else {
// find limits for user roles which match to required roles
userRoles = context.getUserRoles().stream().filter(role -> accessControlled.getUserRoles().contains(role)).toList();
userRoles = context.getUserRoles().stream().filter(role -> roleBasedEntity.getUserRoles().contains(role)).toList();
}
Map<String, Role> roles = context.getConfig().getRoles();
Limit defaultUserLimit = getLimit(roles, DEFAULT_USER_ROLE, name, DEFAULT_LIMIT);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class AccessControlledTest {
public class RoleBasedEntityTest {

@Test
public void testHasAssessByRole_DeploymentRolesEmpty() {
Expand All @@ -37,14 +37,6 @@ public void testHasAssessByRole_RoleMismatch() {
assertFalse(deployment.hasAccess(Collections.emptyList()));
}

@Test
public void testHasAssessByRole_RoleIsNull() {
Deployment deployment = new Model();
deployment.setUserRoles(Set.of("role1"));

assertFalse(deployment.hasAccess(null));
}

@Test
public void testHasAssessByRole_Success() {
Deployment deployment = new Model();
Expand Down

0 comments on commit 9998347

Please sign in to comment.