From 46805de175d9ba68d77b906afb200273b78166bd Mon Sep 17 00:00:00 2001 From: Luke Bemish Date: Tue, 10 Dec 2024 18:33:06 -0600 Subject: [PATCH] Refactor concurrency handling code --- .../lukebemish/immaculate/ForkFormatter.java | 183 ++++++++++++------ .../lukebemish/immaculate/wrapper/Main.java | 65 ++++--- 2 files changed, 159 insertions(+), 89 deletions(-) diff --git a/src/main/java/dev/lukebemish/immaculate/ForkFormatter.java b/src/main/java/dev/lukebemish/immaculate/ForkFormatter.java index f98d813..953d3af 100644 --- a/src/main/java/dev/lukebemish/immaculate/ForkFormatter.java +++ b/src/main/java/dev/lukebemish/immaculate/ForkFormatter.java @@ -3,14 +3,13 @@ import java.io.BufferedReader; import java.io.DataInputStream; import java.io.DataOutputStream; -import java.io.EOFException; import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.io.UncheckedIOException; import java.net.InetAddress; import java.net.Socket; -import java.nio.charset.StandardCharsets; +import java.net.SocketException; import java.util.ArrayList; import java.util.List; import java.util.Map; @@ -24,7 +23,6 @@ public class ForkFormatter implements FileFormatter { private final Process process; - private final Socket socket; private final ResultListener listener; public ForkFormatter(ForkFormatterSpec spec) { @@ -35,6 +33,9 @@ public ForkFormatter(ForkFormatterSpec spec) { List args = new ArrayList<>(); args.add(spec.getJavaLauncher().get().getExecutablePath().getAsFile().toString()); args.addAll(spec.getJvmArgs().get()); + if (spec.getHideStacktrace().get()) { + args.add("-Ddev.lukebemish.immaculate.wrapper.hidestacktrace=true"); + } args.addAll(List.of( "-cp", spec.getClasspath().getAsPath(), @@ -66,9 +67,7 @@ public ForkFormatter(ForkFormatterSpec spec) { try { String socketPortString = socketPort.get(4000, TimeUnit.MILLISECONDS); int port = Integer.parseInt(socketPortString); - this.socket = new Socket(InetAddress.getLoopbackAddress(), port); - - this.listener = new ResultListener(socket); + this.listener = new ResultListener(new Socket(InetAddress.getLoopbackAddress(), port)); this.listener.start(); } catch (InterruptedException | ExecutionException | TimeoutException | IOException e) { throw new RuntimeException(e); @@ -103,49 +102,132 @@ public void run() { } } + private static final class SocketHandle { + private final DataOutputStream output; + private final DataInputStream input; + private final Socket socket; + + private SocketHandle(Socket socket) throws IOException { + this.output = new DataOutputStream(socket.getOutputStream()); + this.input = new DataInputStream(socket.getInputStream()); + this.socket = socket; + } + + synchronized void writeSubmission(int id, String fileName, String text) throws IOException { + output.writeInt(id); + output.writeUTF(fileName); + output.writeUTF(text); + output.flush(); + } + + // Will be true only if a shutdown signal is successfully sent to the channel. + private volatile boolean gracefulShutdown = false; + + synchronized void shutdown() throws IOException { + try { + // -1 ID signals the end of submissions + output.writeInt(-1); + output.flush(); + this.gracefulShutdown = true; + } finally { + // Then close the socket + socket.close(); + } + } + + int readId() throws IOException { + try { + return input.readInt(); + } catch (SocketException e) { + // Could be the socket is intentionally closed during cleanup, could be something went sideways. + // To differentiate -- check gracefulShutdown + if (gracefulShutdown) { + return -1; + } + throw e; + } + } + + boolean readSuccess() throws IOException { + return input.readBoolean(); + } + + String readResult() throws IOException { + return input.readUTF(); + } + } + private static final class ResultListener extends Thread { private final Map> results = new ConcurrentHashMap<>(); - private final Socket socket; - private final DataOutputStream output; + private final SocketHandle socketHandle; + // Handle uncaught exceptions by re-throwing them on shutdown + private volatile Throwable thrownException; private ResultListener(Socket socket) throws IOException { - this.socket = socket; - output = new DataOutputStream(socket.getOutputStream()); + this.socketHandle = new SocketHandle(socket); this.setUncaughtExceptionHandler((t, e) -> { try { shutdown(e); + thrownException = e; } catch (IOException ex) { var exception = new UncheckedIOException(ex); exception.addSuppressed(e); - ResultListener.this.getThreadGroup().uncaughtException(t, exception); + thrownException = exception; } - ResultListener.this.getThreadGroup().uncaughtException(t, e); }); } - public synchronized CompletableFuture submit(int id, String fileName, String text) throws IOException { + // Non-blocking, returns a future that will complete when the result is available (or throws if the listener is closed early unexpectedly) + public CompletableFuture submit(int id, String fileName, String text) throws IOException { if (closed.get()) { - throw new IOException("Listener is closed"); + return CompletableFuture.failedFuture(new IOException("Listener is closed")); } var out = results.computeIfAbsent(id, i -> new CompletableFuture<>()); - output.writeInt(id); - byte[] fileNameBytes = fileName.getBytes(StandardCharsets.UTF_8); - byte[] textBytes = text.getBytes(StandardCharsets.UTF_8); - output.writeInt(fileNameBytes.length); - output.write(fileNameBytes); - output.writeInt(textBytes.length); - output.write(textBytes); - output.flush(); + // Submissions to the child process take the format ID, file name, file contents -- the ID lets the result be matched up + socketHandle.writeSubmission(id, fileName, text); return out; } private final AtomicBoolean closed = new AtomicBoolean(); - public void shutdown() throws IOException { + // Blocks until proper thread shutdown + public void ensureShutdown() throws Throwable { + /* + Cleans up the child process, stops the listener thread, and joins the thread ensuring it is closed, rethrowing + exceptions as necessary. + */ shutdown(new IOException("Execution was interrupted")); + + try { + this.join(); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + + if (thrownException != null) { + throw thrownException; + } } + // Non-blocking private void shutdown(Throwable t) throws IOException { + /* + This method handles graceful shutdown of the child process and forcing the listener thread to stop. It does + not ensure the listener thread is closed. + - ensure the shutdown logic runs exactly once, in the proper order; it is possible for logic running on the + thread to request a shutdown during the shutdown process initialized from another thread. + - prevent submission of new tasks + - complete all pending tasks exceptionally (with the provided exception if one is given) + - stop the child process (by sending it a "shutdown" signal with ID -1) + - stop the thread if it is running. The thread could be waiting at a number of places. Either: + - the readId() call, if everything is running normally + - the readResult() or readSuccess() call, if something is going badly wrong in the child process + - not waiting, just in the loop -- the "closed" flag will be checked at the top of the loop + to stop in either of these cases, we simply close the socket; this results in anything blocking on reading + from the socket throwing an exception (see Socket#close()). + */ + + // Prevent multiple concurrent shutdowns if (!this.closed.compareAndSet(false, true)) return; for (var future : results.values()) { @@ -153,33 +235,23 @@ private void shutdown(Throwable t) throws IOException { } results.clear(); - socket.shutdownInput(); - - if (Thread.currentThread() != this) { - try { - this.join(); - } catch (InterruptedException e) { - // continue, it's fine - } - } - - output.writeInt(-1); - output.flush(); - socket.close(); + socketHandle.shutdown(); } @Override public void run() { try { if (!closed.get()) { - var input = new DataInputStream(socket.getInputStream()); while (!closed.get()) { - int id = input.readInt(); - boolean success = input.readBoolean(); + int id = socketHandle.readId(); + if (id == -1) { + // The child process has been sent a shutdown signal gracefully + shutdown(new IOException("Listener is closed")); + break; + } + boolean success = socketHandle.readSuccess(); if (success) { - int length = input.readInt(); - byte[] bytes = input.readNBytes(length); - String result = new String(bytes, StandardCharsets.UTF_8); + String result = socketHandle.readResult(); var future = results.remove(id); if (future != null) { future.complete(result); @@ -193,12 +265,6 @@ public void run() { } } } - } catch (EOFException e) { - try { - shutdown(e); - } catch (IOException ex) { - throw new UncheckedIOException(ex); - } } catch (IOException e) { throw new UncheckedIOException(e); } @@ -206,28 +272,21 @@ public void run() { } @Override - public synchronized void close() { - List suppressed = new ArrayList<>(); + public void close() { + List suppressed = new ArrayList<>(); if (listener != null) { try { - listener.shutdown(); - } catch (Exception e) { - suppressed.add(e); - } - } - if (socket != null) { - try { - socket.close(); - } catch (Exception e) { - suppressed.add(e); + listener.ensureShutdown(); + } catch (Throwable t) { + suppressed.add(t); } } if (process != null) { try { process.destroy(); process.waitFor(); - } catch (Exception e) { - suppressed.add(e); + } catch (Throwable t) { + suppressed.add(t); } } if (!suppressed.isEmpty()) { diff --git a/wrapper/src/main/java/dev/lukebemish/immaculate/wrapper/Main.java b/wrapper/src/main/java/dev/lukebemish/immaculate/wrapper/Main.java index 78ca290..ef7c821 100644 --- a/wrapper/src/main/java/dev/lukebemish/immaculate/wrapper/Main.java +++ b/wrapper/src/main/java/dev/lukebemish/immaculate/wrapper/Main.java @@ -5,7 +5,7 @@ import java.io.IOException; import java.lang.reflect.Constructor; import java.net.ServerSocket; -import java.nio.charset.StandardCharsets; +import java.net.Socket; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; @@ -57,29 +57,30 @@ private void run() throws IOException { // This tells the parent process what port we're listening on System.out.println(socket.getLocalPort()); var socket = this.socket.accept(); - var input = new DataInputStream(socket.getInputStream()); - var os = new DataOutputStream(socket.getOutputStream()); - var output = new Output(os); + // Communication back to the parent is done through this handle, which ensures synchronization on the output stream. + var socketHandle = new SocketHandle(socket); while (true) { - int id = input.readInt(); + int id = socketHandle.readId(); if (id == -1) { + // We have been sent a signal to gracefully shutdown, so we stop processing new submissions break; } - String fileName = new String(input.readNBytes(input.readInt()), StandardCharsets.UTF_8); - String text = new String(input.readNBytes(input.readInt()), StandardCharsets.UTF_8); - execute(id, fileName, text, output); + String fileName = socketHandle.readUTF(); + String text = socketHandle.readUTF(); + // Submissions to the child process take the format ID, file name, file contents + execute(id, fileName, text, socketHandle); } } - private void execute(int id, String fileName, String text, Output output) { + private void execute(int id, String fileName, String text, SocketHandle socketHandle) { executor.submit(() -> { try { String result = wrapper.format(fileName, text); - output.writeSuccess(id, result); + socketHandle.writeSuccess(id, result); } catch (Throwable t) { logException(t); try { - output.writeFailure(id); + socketHandle.writeFailure(id); } catch (IOException e) { throw new RuntimeException(e); } @@ -96,24 +97,34 @@ private static void logException(Throwable t) { } } - private record Output(DataOutputStream output) { - void writeFailure(int id) throws IOException { - synchronized (this) { - output.writeInt(id); - output.writeBoolean(false); - output.flush(); - } + private static final class SocketHandle { + private final DataOutputStream output; + private final DataInputStream input; + + private SocketHandle(Socket socket) throws IOException { + this.output = new DataOutputStream(socket.getOutputStream()); + this.input = new DataInputStream(socket.getInputStream()); } - void writeSuccess(int id, String result) throws IOException { - synchronized (this) { - output.writeInt(id); - output.writeBoolean(true); - byte[] bytes = result.getBytes(StandardCharsets.UTF_8); - output.writeInt(bytes.length); - output.write(bytes); - output.flush(); - } + synchronized void writeFailure(int id) throws IOException { + output.writeInt(id); + output.writeBoolean(false); + output.flush(); + } + + synchronized void writeSuccess(int id, String result) throws IOException { + output.writeInt(id); + output.writeBoolean(true); + output.writeUTF(result); + output.flush(); + } + + int readId() throws IOException { + return input.readInt(); + } + + String readUTF() throws IOException { + return input.readUTF(); } } }