Skip to content

Commit

Permalink
update lambda layer to accomodate lambda updates
Browse files Browse the repository at this point in the history
  • Loading branch information
sigpwned committed Jun 21, 2024
1 parent a79db2e commit 0b9b975
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 108 deletions.
2 changes: 1 addition & 1 deletion lambda-exec-wrapper.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ env

# Unpack our tmpdump, if we have one
pushd /tmp
/var/lang/bin/java -classpath /var/runtime/lib/aws-lambda-java-core-1.2.3.jar:/var/runtime/lib/aws-lambda-java-runtime-interface-client-2.4.1-linux-x86_64.jar:/var/runtime/lib/aws-lambda-java-serialization-1.1.2.jar:/opt/humangraphics io.humangraphics.backend.lambda.TmpDump "s3://$HUMANGRAPHICS_BUCKET/tmpdump/$AWS_LAMBDA_FUNCTION_NAME.zip"
/var/lang/bin/java -classpath /var/runtime/lib/aws-lambda-java-core-1.2.3.jar:/var/runtime/lib/aws-lambda-java-runtime-interface-client-2.5.0-linux-x86_64.jar:/var/runtime/lib/aws-lambda-java-serialization-1.1.5.jar:/opt/humangraphics io.humangraphics.backend.lambda.TmpDump "s3://$HUMANGRAPHICS_BUCKET/tmpdump/$AWS_LAMBDA_FUNCTION_NAME.zip"
popd

# Grab our args
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@
<dependency>
<groupId>com.amazonaws</groupId>
<artifactId>aws-lambda-java-runtime-interface-client</artifactId>
<version>2.4.1</version>
<version>2.5.0</version>
<scope>provided</scope>
</dependency>
<!-- For testing -->
Expand Down
Original file line number Diff line number Diff line change
@@ -1,44 +1,43 @@
//
// AWSLambda.java
//
// Copyright (c) 2013 Amazon. All rights reserved.
//
/*
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. SPDX-License-Identifier:
* Apache-2.0
*/
package io.humangraphics.backend.lambda.thirdparty.com.amazonaws.services.lambda.runtime.api.client;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileDescriptor;
import java.io.FileInputStream;
import java.io.IOError;
import java.io.IOException;
import java.io.OutputStream;
import java.io.PrintStream;
import java.lang.reflect.Constructor;
import java.security.Security;
import java.util.Properties;
import com.amazonaws.services.lambda.crac.Core;
import com.amazonaws.services.lambda.runtime.LambdaLogger;
import com.amazonaws.services.lambda.runtime.api.client.EventHandlerLoader;
import com.amazonaws.services.lambda.runtime.api.client.Failure;
import com.amazonaws.services.lambda.runtime.api.client.HandlerInfo;
import com.amazonaws.services.lambda.runtime.api.client.LambdaEnvironment;
import com.amazonaws.services.lambda.runtime.api.client.LambdaRequestHandler;
import com.amazonaws.services.lambda.runtime.api.client.LambdaRequestHandler.UserFaultHandler;
import com.amazonaws.services.lambda.runtime.api.client.ReservedRuntimeEnvironmentVariables;
import com.amazonaws.services.lambda.runtime.api.client.UserFault;
import com.amazonaws.services.lambda.runtime.api.client.XRayErrorCause;
import com.amazonaws.services.lambda.runtime.api.client.logging.FramedTelemetryLogSink;
import com.amazonaws.services.lambda.runtime.api.client.logging.LambdaContextLogger;
import com.amazonaws.services.lambda.runtime.api.client.logging.LogSink;
import com.amazonaws.services.lambda.runtime.api.client.logging.StdOutLogSink;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.InvocationRequest;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.LambdaRuntimeClient;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.LambdaRuntimeApiClient;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.LambdaRuntimeApiClientImpl;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.converters.LambdaErrorConverter;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.converters.XRayErrorCauseConverter;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.dto.InvocationRequest;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.dto.LambdaError;
import com.amazonaws.services.lambda.runtime.api.client.runtimeapi.dto.XRayErrorCause;
import com.amazonaws.services.lambda.runtime.api.client.util.LambdaOutputStream;
import com.amazonaws.services.lambda.runtime.api.client.util.UnsafeUtil;
import com.amazonaws.services.lambda.runtime.logging.LogFormat;
import com.amazonaws.services.lambda.runtime.logging.LogLevel;
import com.amazonaws.services.lambda.runtime.serialization.PojoSerializer;
import com.amazonaws.services.lambda.runtime.serialization.factories.GsonFactory;
import com.amazonaws.services.lambda.runtime.serialization.factories.JacksonFactory;
import com.amazonaws.services.lambda.runtime.serialization.util.ReflectUtil;


Expand Down Expand Up @@ -76,6 +75,10 @@ public class AWSLambda {
private static final String AWS_LAMBDA_INITIALIZATION_TYPE =
System.getenv(ReservedRuntimeEnvironmentVariables.AWS_LAMBDA_INITIALIZATION_TYPE);

protected static ClassLoader customerClassLoader;

private static LambdaRuntimeApiClient runtimeClient;

static {
// Override the disabledAlgorithms setting to match configuration for openjdk8-u181.
// This is to keep DES ciphers around while we deploying security updates.
Expand Down Expand Up @@ -144,23 +147,20 @@ private static LambdaRequestHandler findRequestHandler(final String handlerStrin
return requestHandler;
}

/**
* aboothe 20240620 copied from {@link UserFault} to fix package visibility issue
*/
private static UserFault makeInitErrorUserFault(Throwable e, String className) {
return new UserFault(
"Error loading class " + className + (e.getMessage() == null ? "" : ": " + e.getMessage()),
e.getClass().getName(), UserFault.trace(e), true);
}

public static void setupRuntimeLogger(LambdaLogger lambdaLogger) throws ClassNotFoundException {
ReflectUtil.setStaticField(Class.forName("com.amazonaws.services.lambda.runtime.LambdaRuntime"),
"logger", true, lambdaLogger);
}

public static String getEnvOrExit(String envVariableName) {
String value = System.getenv(envVariableName);
if (value == null) {
System.err.println("Could not get environment variable " + envVariableName);
System.exit(-1);
}
return value;
}

// protected static URLClassLoader customerClassLoader;
protected static ClassLoader customerClassLoader;

/**
* convert an integer into a FileDescriptor object using reflection to access private members.
*/
Expand Down Expand Up @@ -196,7 +196,7 @@ public static void main(String[] args) {

private static void startRuntime(String handler) {
try (LogSink logSink = createLogSink()) {
LambdaLogger logger =
LambdaContextLogger logger =
new LambdaContextLogger(logSink, LogLevel.fromString(LambdaEnvironment.LAMBDA_LOG_LEVEL),
LogFormat.fromString(LambdaEnvironment.LAMBDA_LOG_FORMAT));
startRuntime(handler, logger);
Expand All @@ -205,43 +205,52 @@ private static void startRuntime(String handler) {
}
}

private static void startRuntime(String handler, LambdaLogger lambdaLogger) throws Throwable {
private static void startRuntime(String handler, LambdaContextLogger lambdaLogger)
throws Throwable {
UnsafeUtil.disableIllegalAccessWarning();

System.setOut(new PrintStream(new LambdaOutputStream(System.out), false, "UTF-8"));
System.setErr(new PrintStream(new LambdaOutputStream(System.err), false, "UTF-8"));
setupRuntimeLogger(lambdaLogger);

String runtimeApi = getEnvOrExit(ReservedRuntimeEnvironmentVariables.AWS_LAMBDA_RUNTIME_API);
LambdaRuntimeClient runtimeClient = new LambdaRuntimeClient(runtimeApi);
runtimeClient = new LambdaRuntimeApiClientImpl(LambdaEnvironment.RUNTIME_API);

String taskRoot = System.getProperty("user.dir");
String libRoot = "/opt/java";
// Make system classloader the customer classloader's parent to ensure any aws-lambda-java-core
// classes
// are loaded from the system classloader.
customerClassLoader = ClassLoader.getSystemClassLoader();
// customerClassLoader = ClassLoader.getSystemClassLoader();


// aboothe 20240620 just use the system class loader. This allows us to load code from the
// ServiceLoader, which we otherwise could not do using the custom class loader.
//
// customerClassLoader =
// new CustomerClassLoader(taskRoot, libRoot, ClassLoader.getSystemClassLoader());
//
customerClassLoader = ClassLoader.getSystemClassLoader();

Thread.currentThread().setContextClassLoader(customerClassLoader);

// Load the user's handler
LambdaRequestHandler requestHandler;
try {
requestHandler = findRequestHandler(handler, customerClassLoader);
} catch (UserFault userFault) {
lambdaLogger.log(userFault.reportableError());
reportInitError(new Failure(userFault), runtimeClient);
lambdaLogger.log(userFault.reportableError(),
lambdaLogger.getLogFormat() == LogFormat.JSON ? LogLevel.ERROR : LogLevel.UNDEFINED);
LambdaError error = LambdaErrorConverter.fromUserFault(userFault);
runtimeClient.reportInitError(error);
System.exit(1);
return;
}
if (INIT_TYPE_SNAP_START.equals(AWS_LAMBDA_INITIALIZATION_TYPE)) {
onInitComplete(runtimeClient, lambdaLogger);
onInitComplete(lambdaLogger);
}
boolean shouldExit = false;
while (!shouldExit) {
UserFault userFault = null;
InvocationRequest request = runtimeClient.waitForNextInvocation();
InvocationRequest request = runtimeClient.nextInvocation();
if (request.getXrayTraceId() != null) {
System.setProperty(LAMBDA_TRACE_HEADER_PROP, request.getXrayTraceId());
} else {
Expand All @@ -251,107 +260,57 @@ private static void startRuntime(String handler, LambdaLogger lambdaLogger) thro
ByteArrayOutputStream payload;
try {
payload = requestHandler.call(request);
runtimeClient.postInvocationResponse(request.getId(), payload.toByteArray());
runtimeClient.reportInvocationSuccess(request.getId(), payload.toByteArray());
boolean ignored = Thread.interrupted(); // clear interrupted flag in case if it was set by
// user's code
} catch (UserFault f) {
shouldExit = f.fatal;
userFault = f;
UserFault.filterStackTrace(f);
payload = new ByteArrayOutputStream(1024);
Failure failure = new Failure(f);
GsonFactory.getInstance().getSerializer(Failure.class).toJson(failure, payload);
shouldExit = f.fatal;
runtimeClient.postInvocationError(request.getId(), payload.toByteArray(),
failure.getErrorType());

LambdaError error = LambdaErrorConverter.fromUserFault(f);
runtimeClient.reportInvocationError(request.getId(), error);
} catch (Throwable t) {
shouldExit = t instanceof VirtualMachineError || t instanceof IOError;
UserFault.filterStackTrace(t);
userFault = UserFault.makeUserFault(t);
payload = new ByteArrayOutputStream(1024);
Failure failure = new Failure(t);
GsonFactory.getInstance().getSerializer(Failure.class).toJson(failure, payload);
// These two categories of errors are considered fatal.
shouldExit = Failure.isInvokeFailureFatal(t);
runtimeClient.postInvocationError(request.getId(), payload.toByteArray(),
failure.getErrorType(), serializeAsXRayJson(t));

LambdaError error = LambdaErrorConverter.fromThrowable(t);
XRayErrorCause xRayErrorCause = XRayErrorCauseConverter.fromThrowable(t);
runtimeClient.reportInvocationError(request.getId(), error, xRayErrorCause);
} finally {
if (userFault != null) {
lambdaLogger.log(userFault.reportableError());
lambdaLogger.log(userFault.reportableError(),
lambdaLogger.getLogFormat() == LogFormat.JSON ? LogLevel.ERROR : LogLevel.UNDEFINED);
}
}
}
}

static void onInitComplete(final LambdaRuntimeClient runtimeClient,
final LambdaLogger lambdaLogger) throws IOException {
static void onInitComplete(final LambdaContextLogger lambdaLogger) throws IOException {
try {
Core.getGlobalContext().beforeCheckpoint(null);
// Blocking call to RAPID /restore/next API, will return after taking snapshot.
// This will also be the 'entrypoint' when resuming from snapshots.
runtimeClient.getRestoreNext();
runtimeClient.restoreNext();
} catch (Exception e1) {
logExceptionCloudWatch(lambdaLogger, e1);
reportInitError(new Failure(e1), runtimeClient);
LambdaError error = LambdaErrorConverter.fromThrowable(e1);
runtimeClient.reportInitError(error);
System.exit(64);
}
try {
Core.getGlobalContext().afterRestore(null);
} catch (Exception restoreExc) {
logExceptionCloudWatch(lambdaLogger, restoreExc);
Failure errorPayload = new Failure(restoreExc);
reportRestoreError(errorPayload, runtimeClient);
LambdaError error = LambdaErrorConverter.fromThrowable(restoreExc);
runtimeClient.reportRestoreError(error);
System.exit(64);
}
}

private static void logExceptionCloudWatch(LambdaLogger lambdaLogger, Exception exc) {
private static void logExceptionCloudWatch(LambdaContextLogger lambdaLogger, Exception exc) {
UserFault.filterStackTrace(exc);
UserFault userFault = UserFault.makeUserFault(exc, true);
lambdaLogger.log(userFault.reportableError());
}

static void reportInitError(final Failure failure, final LambdaRuntimeClient runtimeClient)
throws IOException {

ByteArrayOutputStream payload = new ByteArrayOutputStream(1024);
JacksonFactory.getInstance().getSerializer(Failure.class).toJson(failure, payload);
runtimeClient.postInitError(payload.toByteArray(), failure.getErrorType());
}

static int reportRestoreError(final Failure failure, final LambdaRuntimeClient runtimeClient)
throws IOException {

ByteArrayOutputStream payload = new ByteArrayOutputStream(1024);
JacksonFactory.getInstance().getSerializer(Failure.class).toJson(failure, payload);
return runtimeClient.postRestoreError(payload.toByteArray(), failure.getErrorType());
}

private static PojoSerializer<XRayErrorCause> xRayErrorCauseSerializer;

/**
* @param throwable throwable to convert
* @return json as string expected by XRay's web console. On conversion failure, returns null.
*/
private static String serializeAsXRayJson(Throwable throwable) {
try {
final OutputStream outputStream = new ByteArrayOutputStream();
final XRayErrorCause cause = new XRayErrorCause(throwable);
if (xRayErrorCauseSerializer == null) {
xRayErrorCauseSerializer = JacksonFactory.getInstance().getSerializer(XRayErrorCause.class);
}
xRayErrorCauseSerializer.toJson(cause, outputStream);
return outputStream.toString();
} catch (Exception e) {
return null;
}
}

/**
* Pulled from com.amazonaws.services.lambda.runtime.api.client.UserFault
*/
static UserFault makeInitErrorUserFault(Throwable e, String className) {
return new UserFault(
"Error loading class " + className + (e.getMessage() == null ? "" : ": " + e.getMessage()),
e.getClass().getName(), UserFault.trace(e), true);
lambdaLogger.log(userFault.reportableError(),
lambdaLogger.getLogFormat() == LogFormat.JSON ? LogLevel.ERROR : LogLevel.UNDEFINED);
}
}

0 comments on commit 0b9b975

Please sign in to comment.