From 928d3216a53e44c20abce09f4aafdea4a2e30e44 Mon Sep 17 00:00:00 2001 From: Johnathan Gilday Date: Wed, 26 Jun 2024 16:18:24 -0400 Subject: [PATCH] =?UTF-8?q?=E2=AC=86=EF=B8=8F=20Upgrade=20GPT=203.5=20Turb?= =?UTF-8?q?o=20Model=20(#398)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit gpt-3.5-turbo-0613 has been removed, so replace it with new models. * Updated failed authentication logging codemod to use new models and `SarifToLLMForMultiOutcomeCodemod`. * Added a `Model` type to describe GPT models and colocates model-specific logic such as token counting. /close #work --- .../codemods/LogFailedLoginCodemod.java | 94 ++-- .../LogFailedLoginCodemod/fix_prompt.txt | 4 - .../LogFailedLoginCodemod/threat_prompt.txt | 18 - .../codemods/LogFailedLoginCodemodTest.java | 35 +- .../vulnerable/MainFrame.java.after | 5 +- .../SaltedHashLoginModule.java.after | 326 ++++++++++++ .../SaltedHashLoginModule.java.before | 0 .../src/test/resources/logback-test.xml | 9 + .../testutils/CodemodTestMixin.java | 9 +- .../io/codemodder/testutils/Metadata.java | 15 + .../llm/BinaryThreatAnalysisAndFix.java | 15 + .../CodeChangingLLMRemediationOutcome.java | 19 + .../plugins/llm/LLMRemediationOutcome.java | 17 + .../java/io/codemodder/plugins/llm/Model.java | 30 ++ .../llm/NoActionLLMRemediationOutcome.java | 23 + ...ForBinaryVerificationAndFixingCodemod.java | 81 ++- .../llm/SarifToLLMForMultiOutcomeCodemod.java | 484 ++++++++++++++++++ .../codemodder/plugins/llm/StandardModel.java | 49 ++ .../io/codemodder/plugins/llm/Tokens.java | 16 +- .../test/LLMVerifyingCodemodTestMixin.java | 11 +- 20 files changed, 1163 insertions(+), 97 deletions(-) delete mode 100644 core-codemods/src/main/resources/io/codemodder/codemods/LogFailedLoginCodemod/fix_prompt.txt delete mode 100644 core-codemods/src/main/resources/io/codemodder/codemods/LogFailedLoginCodemod/threat_prompt.txt create mode 100644 
core-codemods/src/test/resources/log-failed-login/vulnerable/SaltedHashLoginModule.java.after rename core-codemods/src/test/resources/log-failed-login/{safe => vulnerable}/SaltedHashLoginModule.java.before (100%) create mode 100644 core-codemods/src/test/resources/logback-test.xml create mode 100644 plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/CodeChangingLLMRemediationOutcome.java create mode 100644 plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/LLMRemediationOutcome.java create mode 100644 plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/Model.java create mode 100644 plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/NoActionLLMRemediationOutcome.java create mode 100644 plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/SarifToLLMForMultiOutcomeCodemod.java create mode 100644 plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/StandardModel.java diff --git a/core-codemods/src/main/java/io/codemodder/codemods/LogFailedLoginCodemod.java b/core-codemods/src/main/java/io/codemodder/codemods/LogFailedLoginCodemod.java index 9ec8bbe2b..7e38109e8 100644 --- a/core-codemods/src/main/java/io/codemodder/codemods/LogFailedLoginCodemod.java +++ b/core-codemods/src/main/java/io/codemodder/codemods/LogFailedLoginCodemod.java @@ -1,14 +1,11 @@ package io.codemodder.codemods; -import static io.codemodder.CodemodResources.getClassResourceAsString; - -import com.contrastsecurity.sarif.Result; -import com.github.difflib.patch.AbstractDelta; -import com.github.difflib.patch.InsertDelta; -import com.github.difflib.patch.Patch; import io.codemodder.*; +import io.codemodder.plugins.llm.CodeChangingLLMRemediationOutcome; +import io.codemodder.plugins.llm.NoActionLLMRemediationOutcome; import io.codemodder.plugins.llm.OpenAIService; -import io.codemodder.plugins.llm.SarifToLLMForBinaryVerificationAndFixingCodemod; +import 
io.codemodder.plugins.llm.SarifToLLMForMultiOutcomeCodemod; +import io.codemodder.plugins.llm.StandardModel; import io.codemodder.providers.sarif.semgrep.SemgrepScan; import java.util.List; import javax.inject.Inject; @@ -17,39 +14,68 @@ id = "pixee:java/log-failed-login", importance = Importance.HIGH, reviewGuidance = ReviewGuidance.MERGE_AFTER_REVIEW) -public final class LogFailedLoginCodemod extends SarifToLLMForBinaryVerificationAndFixingCodemod { +public final class LogFailedLoginCodemod extends SarifToLLMForMultiOutcomeCodemod { @Inject public LogFailedLoginCodemod( @SemgrepScan(ruleId = "log-failed-login") final RuleSarif sarif, final OpenAIService openAI) { - super(sarif, openAI); + super( + sarif, + openAI, + List.of( + new NoActionLLMRemediationOutcome( + "logs_failed_login_with_logger", + """ + The code uses a logger to log a message that indicates a failed login attempt. + The message is logged at the INFO or higher level. + """ + .replace('\n', ' ')), + new NoActionLLMRemediationOutcome( + "logs_failed_login_with_console", + """ + The code sends a message to the console that indicates a failed login attempt. + The code may output this message to either System.out or System.err. + """ + .replace('\n', ' ')), + new NoActionLLMRemediationOutcome( + "throws_exception", + """ + The code throws an exception that indicates a failed login attempt. + Throwing such an exception is a reasonable alternative to logging the failed login attempt. + When the username for the failed login is in-scope, the exception message includes the username. + """ + .replace('\n', ' ')), + new NoActionLLMRemediationOutcome( + "no_authentication", + """ + The login validation fails because the request lacks credentials to validate. This is not considered a failed login attempt that requires auditing. + """ + .replace('\n', ' ')), + new CodeChangingLLMRemediationOutcome( + "add_missing_logging", + """ + None of the other outcomes apply. 
+ The code that validates the login credentials does not log a message when the login attempt fails, + NOR does it throw an exception that reasonably indicates a failed login attempt and includes the username in the exception message. + """ + .replace('\n', ' '), + """ + Immediately following the login failure, add precisely one statement to log the failed login attempt at the INFO level. + If the username for the failed login is in scope, the new log message references the username. + Add exactly one such log statement! Exactly one! + The new log statement is consistent with the rest of the code with respect to formatting, braces, casing, etc. + When no logger is in scope, the new code emits a log message to the console. + """ + .replace('\n', ' '))), + StandardModel.GPT_4O, + StandardModel.GPT_4); } @Override - protected String getThreatPrompt( - final CodemodInvocationContext context, final List results) { - return getClassResourceAsString(getClass(), "threat_prompt.txt"); - } - - @Override - protected String getFixPrompt() { - return getClassResourceAsString(getClass(), "fix_prompt.txt"); - } - - @Override - protected boolean isPatchExpected(Patch patch) { - // This codemod should make two or fewer modifications. - if (patch.getDeltas().size() > 2) { - return false; - } - - // This codemod should only insert lines. - for (AbstractDelta delta : patch.getDeltas()) { - if (!(delta instanceof InsertDelta)) { - return false; - } - } - - return true; + protected String getThreatPrompt() { + return """ + The tool has cited an authentication check that does not include a means for auditing failed login attempt. 
+ """ + .replace('\n', ' '); } } diff --git a/core-codemods/src/main/resources/io/codemodder/codemods/LogFailedLoginCodemod/fix_prompt.txt b/core-codemods/src/main/resources/io/codemodder/codemods/LogFailedLoginCodemod/fix_prompt.txt deleted file mode 100644 index 520bf98e0..000000000 --- a/core-codemods/src/main/resources/io/codemodder/codemods/LogFailedLoginCodemod/fix_prompt.txt +++ /dev/null @@ -1,4 +0,0 @@ -- Log a message at the start of the block handling failed login attempts. -- If the username from the failed login attempt is in the block's scope, include it in the message. -- Use the provided logger in the matching style if it is present in the file, and use the warning log level. -- If no logger is used in the file, use `System.out.println`. diff --git a/core-codemods/src/main/resources/io/codemodder/codemods/LogFailedLoginCodemod/threat_prompt.txt b/core-codemods/src/main/resources/io/codemodder/codemods/LogFailedLoginCodemod/threat_prompt.txt deleted file mode 100644 index 05fb62058..000000000 --- a/core-codemods/src/main/resources/io/codemodder/codemods/LogFailedLoginCodemod/threat_prompt.txt +++ /dev/null @@ -1,18 +0,0 @@ -This security threat only applies to code that validates user login credentials (e.g. a username and password). Code is vulnerable to this threat if it does not log failed login attempts. - -Logging failed login attempts is important for security reasons, as it can help detect and prevent brute force attacks and other malicious activities. - -In your analysis, please answer the following: -- Does the code include login validation? -- What logging framework does the code use? Does the code log any messages when login validation fails? -- Does the code print any messages to the console when login validation fails? -- Does the code throw an exception when login validation fails? If so, what type of exception? Throwing a type of login exception may be considered logging the failed login attempt. 
- -Examples of LOW risk: -- code does not validate user login credentials -- code logs a message at INFO severity or higher when login validation fails -- code prints to the console using `System.out.println` when login validation fails -- code throws a type of login exception when login validation fails - -Examples of HIGH risk: -- when login validation fails, code does not log or print a message or throw a type of login exception diff --git a/core-codemods/src/test/java/io/codemodder/codemods/LogFailedLoginCodemodTest.java b/core-codemods/src/test/java/io/codemodder/codemods/LogFailedLoginCodemodTest.java index 7ec69c420..79e478003 100644 --- a/core-codemods/src/test/java/io/codemodder/codemods/LogFailedLoginCodemodTest.java +++ b/core-codemods/src/test/java/io/codemodder/codemods/LogFailedLoginCodemodTest.java @@ -5,12 +5,45 @@ import io.codemodder.testutils.Metadata; import org.junit.jupiter.api.Disabled; +/** + * Tests for the {@link LogFailedLoginCodemod}. + * + *

Test cases that should not have code changes: + * + *

+ *
safe/AuthProvider.java.before + *
Describes a type that performs authentication, but no authentication is implemented here. + * 
safe/JaaSAuthenticationBroker.java.before + *
Throws exceptions that indicate failed login attempts. + *
safe/LoginServlet.java.before + *
Logs authentication failures at the WARN level. + * 
safe/Main.java.before + *
Logs a message when authentication fails. + *
safe/MainPrint.before + *
Prints to the console when a login attempt fails. + * 
safe/Queue.java.before + *
is too large to be analyzed. + *
+ * + * Test cases that should have code changes: + * + *
+ *
unsafe/LoginServlet.java.before + *
lacks a log statement before returning an unauthorized response. + * 
unsafe/LoginValidate.java.before + *
lacks a print statement before redirecting to error page. + *
unsafe/MainFrame.java.before + * 
lacks a log statement before showing the dialog. + *
unsafe/SaltedHashLoginModule.java.before + * 
lacks a log statement before returning the authenticated decision. That is the correct + * place to log, because it has the username in scope. + */ @Metadata( codemodType = LogFailedLoginCodemod.class, testResourceDir = "log-failed-login", dependencies = {}) @OpenAIIntegrationTest -@Disabled("codemod is in disrepair") +@Disabled("codemod in disrepair - behavior is too indeterminate") public final class LogFailedLoginCodemodTest implements LLMVerifyingCodemodTestMixin { @Override diff --git a/core-codemods/src/test/resources/log-failed-login/vulnerable/MainFrame.java.after b/core-codemods/src/test/resources/log-failed-login/vulnerable/MainFrame.java.after index 638733d40..9d96d8202 100644 --- a/core-codemods/src/test/resources/log-failed-login/vulnerable/MainFrame.java.after +++ b/core-codemods/src/test/resources/log-failed-login/vulnerable/MainFrame.java.after @@ -154,7 +154,7 @@ ResultSet rs = stmt.executeQuery(sql); this.dispose(); } else{ - Logger.getLogger(MainFrame.class.getName()).log(Level.WARNING, "Failed login attempt for user: " + uname1); + Logger.getLogger(MainFrame.class.getName()).log(Level.INFO, "Failed login attempt for user: " + uname1); JOptionPane.showMessageDialog(null, "Incorrect Username Or Password", "Login Failed", 2); } @@ -211,6 +211,3 @@ ResultSet rs = stmt.executeQuery(sql); private javax.swing.JTextField upass; // End of variables declaration//GEN-END:variables } - - - diff --git a/core-codemods/src/test/resources/log-failed-login/vulnerable/SaltedHashLoginModule.java.after b/core-codemods/src/test/resources/log-failed-login/vulnerable/SaltedHashLoginModule.java.after new file mode 100644 index 000000000..cdd633ac3 --- /dev/null +++ b/core-codemods/src/test/resources/log-failed-login/vulnerable/SaltedHashLoginModule.java.after @@ -0,0 +1,326 @@ +package de.meetwithfriends.security.jaas; + +import de.meetwithfriends.security.jaas.principal.RolePrincipal; +import de.meetwithfriends.security.jaas.principal.UserPrincipal; +import 
de.meetwithfriends.security.jdbc.AuthenticationDao; +import de.meetwithfriends.security.jdbc.JdbcAuthenticationService; +import de.meetwithfriends.security.jdbc.data.ConfigurationData; +import de.meetwithfriends.security.util.StringUtil; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.security.Principal; +import java.text.MessageFormat; +import java.util.List; +import java.util.Map; +import javax.security.auth.Subject; +import javax.security.auth.callback.*; +import javax.security.auth.login.FailedLoginException; +import javax.security.auth.login.LoginException; +import javax.security.auth.spi.LoginModule; + +public class SaltedHashLoginModule implements LoginModule +{ + private static final Logger LOG = LoggerFactory.getLogger(SaltedHashLoginModule.class); + private static final String ROLE_PRINCIPAL_CLASS_OPTION = "role-principal-class"; + + private Subject subject; + private CallbackHandler callbackHandler; + private Map options; + private boolean debug = false; + + private AuthenticationDao authenticationDao; + + private UserPrincipal userPrincipal; + + private String username; + private char[] password; + + private boolean succeeded = false; + private boolean commitSucceeded = false; + private String rolePrincipalClass; + + public static void main(String[] args) throws Exception { + if (null == args || args.length == 0 || args[0].length() == 0) { + LOG.error("need a password arg"); + throw new IllegalArgumentException("Need a password arg"); + } + + String salt = StringUtil.getRandomHexString(args.length > 1 ? 
Integer.parseInt(args[1]) : 32); + String password = JdbcAuthenticationService.getSaltedPasswordDigest(args[0], salt, "SHA-256"); + System.out.println(" Password -> " + password); + System.out.println(" Salt -> " + salt); + } + + @Override + public void initialize(Subject subject, CallbackHandler callbackHandler, Map sharedState, Map options) + { + this.subject = subject; + this.callbackHandler = callbackHandler; + this.options = options; + this.rolePrincipalClass = (String) options.get(ROLE_PRINCIPAL_CLASS_OPTION); + + initDebugging(); + initAuthenticationDao(); + } + + @Override + public boolean login() throws LoginException + { + if (debug) + { + LOG.debug("login..."); + } + + if (callbackHandler == null) + { + LOG.error("Error: no CallbackHandler available"); + throw new LoginException("Error: no CallbackHandler available"); + } + + Callback[] callbacks = new Callback[2]; + callbacks[0] = new NameCallback("username"); + callbacks[1] = new PasswordCallback("password", false); + + succeeded = requestCredentialsAndAuthenticate(callbacks); + + if (!succeeded) + { + LOG.info("login failed"); + throw new FailedLoginException("login failed"); + } + + if (debug) + { + LOG.debug("authentication succeeded"); + } + + return succeeded; + } + + @Override + public boolean commit() throws LoginException + { + if (debug) + { + LOG.debug("committing authentication"); + } + + if (!succeeded) + { + return false; + } + + userPrincipal = new UserPrincipal(username); + addNonExistentPrincipal(userPrincipal); + + List roles = authenticationDao.loadRoles(username); + for (String role : roles) + { + Principal rolePrincipal = createRole(role); + addNonExistentPrincipal(rolePrincipal); + } + + username = null; + commitSucceeded = true; + + return true; + } + + private Principal createRole(String roleName) { + if (rolePrincipalClass != null && rolePrincipalClass.length() > 0) { + try { + Class clazz = (Class) Class.forName(rolePrincipalClass); + Constructor clazzDeclaredConstructor = 
clazz.getDeclaredConstructor(String.class); + return clazzDeclaredConstructor.newInstance(roleName); + } catch (Exception e) { + LOG.warn("Unable to create instance of class {}, error is {}", rolePrincipalClass, e.getMessage()); + } + } + + return new RolePrincipal(roleName); + } + + @Override + public boolean abort() throws LoginException + { + if (debug) + { + LOG.debug("aborting authentication"); + } + + if (!succeeded) + { + return false; + } + + if (!commitSucceeded) + { + resetData(); + succeeded = false; + } + else + { + logout(); + } + + return true; + } + + @Override + public boolean logout() throws LoginException + { + if (debug) + { + LOG.debug("Removing principal"); + } + + subject.getPrincipals().remove(userPrincipal); + resetData(); + + succeeded = false; + succeeded = commitSucceeded; + + return true; + } + + private void initDebugging() + { + String debugOption = (String) options.get("debug"); + if ("true".equalsIgnoreCase(debugOption)) + { + debug = true; + } + } + + private void initAuthenticationDao() + { + ConfigurationData configurationData = loadConfigurationData(); + authenticationDao = new AuthenticationDao(configurationData); + } + + private boolean requestCredentialsAndAuthenticate(Callback[] callbacks) throws LoginException + { + boolean authenticated = false; + + try + { + callbackHandler.handle(callbacks); + + username = ((NameCallback) callbacks[0]).getName(); + password = loadPassword((PasswordCallback) callbacks[1]); + if (username == null || password.length == 0) + { + LOG.error("Callback handler does not return login data properly"); + throw new LoginException("Callback handler does not return login data properly"); + } + + JdbcAuthenticationService authService = initAuthenticationService(); + authenticated = authService.authenticate(username, new String(password)); + if (authenticated) + { + LOG.info("User " + username + " successfully authenticated"); + } + } + catch (IOException ex) + { + LOG.error("Error during user login", 
ex); + throw new LoginException(ex.toString()); + } + catch (UnsupportedCallbackException ex) + { + String msg = MessageFormat.format("{0} not available to garner authentication information from the user", ex + .getCallback().toString()); + + LOG.error(msg); + throw new LoginException("Error: " + msg); + } + + return authenticated; + } + + private char[] loadPassword(PasswordCallback pwCallback) + { + char[] tmpPassword = pwCallback.getPassword(); + pwCallback.clearPassword(); + + if (tmpPassword == null) + { + tmpPassword = new char[0]; + } + + return copyPassword(tmpPassword); + } + + private char[] copyPassword(char[] tmpPassword) + { + char[] copiedPassword = new char[tmpPassword.length]; + System.arraycopy(tmpPassword, 0, copiedPassword, 0, tmpPassword.length); + + return copiedPassword; + } + + private JdbcAuthenticationService initAuthenticationService() + { + String mdAlgorithm = (String) options.get("mdAlgorithm"); + + JdbcAuthenticationService authService = new JdbcAuthenticationService(authenticationDao, debug); + if (mdAlgorithm != null) + { + authService.setMdAlgorithm(mdAlgorithm); + } + + return authService; + } + + private ConfigurationData loadConfigurationData() + { + String dBUser = (String) options.get("dbUser"); + String dBPassword = (String) options.get("dbPassword"); + String dBUrl = (String) options.get("dbURL"); + String dBDriver = (String) options.get("dbDriver"); + String userQuery = (String) options.get("userQuery"); + String roleQuery = (String) options.get("roleQuery"); + + ConfigurationData configurationData = new ConfigurationData(); + configurationData.setDbUser(dBUser); + configurationData.setDbPassword(dBPassword); + configurationData.setDbUrl(dBUrl); + configurationData.setDbDriver(dBDriver); + configurationData.setUserQuery(userQuery); + configurationData.setRoleQuery(roleQuery); + + return configurationData; + + } + + private void addNonExistentPrincipal(Principal principal) + { + if 
(!subject.getPrincipals().contains(principal)) + { + if (debug) + { + LOG.debug("Adding principal: " + principal); + } + + subject.getPrincipals().add(principal); + } + } + + private void resetData() + { + userPrincipal = null; + username = null; + + if (password != null) + { + for (int i = 0; i < password.length; i++) + { + password[i] = ' '; + password = null; + } + } + } +} diff --git a/core-codemods/src/test/resources/log-failed-login/safe/SaltedHashLoginModule.java.before b/core-codemods/src/test/resources/log-failed-login/vulnerable/SaltedHashLoginModule.java.before similarity index 100% rename from core-codemods/src/test/resources/log-failed-login/safe/SaltedHashLoginModule.java.before rename to core-codemods/src/test/resources/log-failed-login/vulnerable/SaltedHashLoginModule.java.before diff --git a/core-codemods/src/test/resources/logback-test.xml b/core-codemods/src/test/resources/logback-test.xml new file mode 100644 index 000000000..32aee72e1 --- /dev/null +++ b/core-codemods/src/test/resources/logback-test.xml @@ -0,0 +1,9 @@ + + + + + + + + + diff --git a/framework/codemodder-testutils/src/main/java/io/codemodder/testutils/CodemodTestMixin.java b/framework/codemodder-testutils/src/main/java/io/codemodder/testutils/CodemodTestMixin.java index 279f00081..064c8de28 100644 --- a/framework/codemodder-testutils/src/main/java/io/codemodder/testutils/CodemodTestMixin.java +++ b/framework/codemodder-testutils/src/main/java/io/codemodder/testutils/CodemodTestMixin.java @@ -18,6 +18,7 @@ import java.nio.file.StandardCopyOption; import java.util.*; import java.util.function.Function; +import java.util.function.Predicate; import java.util.stream.Collectors; import java.util.stream.Stream; import org.assertj.core.api.Assertions; @@ -85,7 +86,11 @@ default Stream generateTestCases(@TempDir final Path tmpDir) throws metadata.sonarIssuesJsonFiles(), metadata.sonarHotspotsJsonFiles()); }; - return DynamicTest.stream(inputStream, displayNameGenerator, testExecutor); + 
+ final Predicate displayNameFilter = + metadata.only().isEmpty() ? s -> true : s -> s.matches(metadata.only()); + return DynamicTest.stream(inputStream, displayNameGenerator, testExecutor) + .filter(test -> displayNameFilter.test(test.getDisplayName())); } private void verifyCodemod( @@ -261,7 +266,7 @@ private void verifyCodemod( EncodingDetector.create()); CodeTFResult result2 = executor2.execute(List.of(pathToJavaFile)); List changeset2 = result2.getChangeset(); - assertThat(changeset2.size(), is(0)); + assertThat(changeset2, hasSize(0)); String codeAfterSecondTransform = Files.readString(pathToJavaFile); assertThat(codeAfterFirstTransform, equalTo(codeAfterSecondTransform)); diff --git a/framework/codemodder-testutils/src/main/java/io/codemodder/testutils/Metadata.java b/framework/codemodder-testutils/src/main/java/io/codemodder/testutils/Metadata.java index 9281dd357..440773dc5 100644 --- a/framework/codemodder-testutils/src/main/java/io/codemodder/testutils/Metadata.java +++ b/framework/codemodder-testutils/src/main/java/io/codemodder/testutils/Metadata.java @@ -54,4 +54,19 @@ /** Sonar hotspots file names for testing multiple json files */ String[] sonarHotspotsJsonFiles() default {}; + + /** + * Used to filter test execution to only the tests with a display name that matches the given + * regex. This is a test-driven development tool for iterating on a single, dynamic test case. + * + *
+   * @Metadata(
+   *   codemodType = LogFailedLoginCodemod.class,
+   *   testResourceDir = "log-failed-login",
+   *   only = "\\/safe\\/.*",
+   *   dependencies = {})
+   * public final class LogFailedLoginCodemodTest implements CodemodTestMixin {
+   * 
+ */ + String only() default ""; } diff --git a/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/BinaryThreatAnalysisAndFix.java b/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/BinaryThreatAnalysisAndFix.java index c30d55f25..d562fa714 100644 --- a/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/BinaryThreatAnalysisAndFix.java +++ b/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/BinaryThreatAnalysisAndFix.java @@ -18,4 +18,19 @@ public String getFix() { public String getFixDescription() { return fixDescription; } + + @Override + public String toString() { + return "BinaryThreatAnalysisAndFix: \n" + + "\trisk: " + + getRisk() + + "\n" + + "\tanalysis: " + + getAnalysis() + + "\n" + + "\tfix-description: " + + fixDescription + + "\n" + + fix; + } } diff --git a/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/CodeChangingLLMRemediationOutcome.java b/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/CodeChangingLLMRemediationOutcome.java new file mode 100644 index 000000000..fba0e757d --- /dev/null +++ b/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/CodeChangingLLMRemediationOutcome.java @@ -0,0 +1,19 @@ +package io.codemodder.plugins.llm; + +import java.util.Objects; + +/** Models the parameters for a remediation analysis + actual direction for changing the code. 
*/ +public record CodeChangingLLMRemediationOutcome(String key, String description, String fix) + implements LLMRemediationOutcome { + + public CodeChangingLLMRemediationOutcome { + Objects.requireNonNull(key); + Objects.requireNonNull(description); + Objects.requireNonNull(fix); + } + + @Override + public boolean shouldApplyCodeChanges() { + return true; + } +} diff --git a/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/LLMRemediationOutcome.java b/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/LLMRemediationOutcome.java new file mode 100644 index 000000000..3c275b4da --- /dev/null +++ b/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/LLMRemediationOutcome.java @@ -0,0 +1,17 @@ +package io.codemodder.plugins.llm; + +/** Describes a possible remediation outcome. */ +public interface LLMRemediationOutcome { + + /** A small, unique key that identifies this outcome. */ + String key(); + + /** A description of the code that the LLM will attempt to use to match. */ + String description(); + + /** A description of the fix for cases that match this description. */ + String fix(); + + /** Whether this outcome should lead to a code change. */ + boolean shouldApplyCodeChanges(); +} diff --git a/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/Model.java b/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/Model.java new file mode 100644 index 000000000..e65e5d0fd --- /dev/null +++ b/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/Model.java @@ -0,0 +1,30 @@ +package io.codemodder.plugins.llm; + +import com.theokanning.openai.completion.chat.ChatMessage; +import java.util.List; + +/** + * Internal model for a GPT language model. Helps to colocate model-specific logic e.g. token + * counting. + */ +public interface Model { + + /** + * @return well-known model ID e.g. 
gpt-3.5-turbo-0125 + */ + String id(); + + /** + * @return maximum size of the context window supported by this model + */ + int contextWindow(); + + /** + * Estimates the number of tokens the messages will consume when passed to this model. The + * estimate can vary based on the model. + * + * @param messages the list of messages for which to estimate token usage + * @return estimated tokens that would be consumed by the model + */ + int tokens(List messages); +} diff --git a/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/NoActionLLMRemediationOutcome.java b/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/NoActionLLMRemediationOutcome.java new file mode 100644 index 000000000..71207fd77 --- /dev/null +++ b/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/NoActionLLMRemediationOutcome.java @@ -0,0 +1,23 @@ +package io.codemodder.plugins.llm; + +import java.util.Objects; + +/** Models the parameters for a remediation analysis that results in no code changes. 
*/ +public record NoActionLLMRemediationOutcome(String key, String description) + implements LLMRemediationOutcome { + + public NoActionLLMRemediationOutcome { + Objects.requireNonNull(key); + Objects.requireNonNull(description); + } + + @Override + public String fix() { + return "N/A"; + } + + @Override + public boolean shouldApplyCodeChanges() { + return false; + } +} diff --git a/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/SarifToLLMForBinaryVerificationAndFixingCodemod.java b/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/SarifToLLMForBinaryVerificationAndFixingCodemod.java index 026ff61f8..b661840e4 100644 --- a/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/SarifToLLMForBinaryVerificationAndFixingCodemod.java +++ b/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/SarifToLLMForBinaryVerificationAndFixingCodemod.java @@ -1,6 +1,6 @@ package io.codemodder.plugins.llm; -import static io.codemodder.plugins.llm.Tokens.countTokens; +import static io.codemodder.plugins.llm.StandardModel.GPT_3_5_TURBO; import com.contrastsecurity.sarif.Region; import com.contrastsecurity.sarif.Result; @@ -34,14 +34,22 @@ public abstract class SarifToLLMForBinaryVerificationAndFixingCodemod extends SarifPluginRawFileChanger { - private static final Logger logger = - LoggerFactory.getLogger(SarifToLLMForBinaryVerificationAndFixingCodemod.class); private final OpenAIService openAI; + private final Model model; protected SarifToLLMForBinaryVerificationAndFixingCodemod( - final RuleSarif sarif, final OpenAIService openAI) { + final RuleSarif sarif, final OpenAIService openAI, final Model model) { super(sarif); this.openAI = Objects.requireNonNull(openAI); + this.model = Objects.requireNonNull(model); + } + + /** + * For backwards compatibility with a previous version of this API, uses a GPT 3.5 Turbo model. 
+ */ + protected SarifToLLMForBinaryVerificationAndFixingCodemod( + final RuleSarif sarif, final OpenAIService openAI) { + this(sarif, openAI, GPT_3_5_TURBO); } @Override @@ -51,7 +59,7 @@ public CodemodFileScanningResult onFileFound( // For fine-tuning the semgrep rule, debug log the matching snippets in the file. results.forEach( - (result) -> { + result -> { Region region = result.getLocations().get(0).getPhysicalLocation().getRegion(); logger.debug("{}:{}", region.getStartLine(), region.getSnippet().getText()); }); @@ -68,10 +76,7 @@ public CodemodFileScanningResult onFileFound( } BinaryThreatAnalysisAndFix fix = fixThreat(file, context, results); - logger.debug("risk: {}", fix.getRisk()); - logger.debug("analysis: {}", fix.getAnalysis()); - logger.debug("fix: {}", fix.getFix()); - logger.debug("fix description: {}", fix.getFixDescription()); + logger.debug("{}", fix); // If our second look determined that the risk of the threat is low, don't change the file. if (fix.getRisk() == BinaryThreatRisk.LOW) { @@ -79,7 +84,7 @@ public CodemodFileScanningResult onFileFound( } // If the LLM was unable to fix the threat, don't change the file. - if (fix.getFix() == null || fix.getFix().length() == 0) { + if (fix.getFix() == null || fix.getFix().isEmpty()) { logger.info("unable to fix: {}", context.path()); return CodemodFileScanningResult.none(); } @@ -89,7 +94,7 @@ public CodemodFileScanningResult onFileFound( // Ensure the end result isn't wonky. 
Patch patch = DiffUtils.diff(file.getLines(), fixedLines); - if (patch.getDeltas().size() == 0 || !isPatchExpected(patch)) { + if (patch.getDeltas().isEmpty() || !isPatchExpected(patch)) { logger.error("unexpected patch: {}", patch); return CodemodFileScanningResult.none(); } @@ -140,11 +145,10 @@ private BinaryThreatAnalysis analyzeThreat( ChatMessage systemMessage = getSystemMessage(context, results); ChatMessage userMessage = getAnalyzeUserMessage(file); - int tokenCount = countTokens(List.of(systemMessage, userMessage)); - if (tokenCount > 3796) { - // The max tokens for gpt-3.5-turbo-0613 is 4,096. If the estimated token count, which doesn't - // include the function (~100 tokens) or the reply (~200 tokens), is close to the max, assume - // the code is safe (for now). + // If the estimated token count, which doesn't include the function (~100 tokens) or the reply + // (~200 tokens), is close to the max, then assume the code is safe (for now). + int tokenCount = model.tokens(List.of(systemMessage, userMessage)); + if (tokenCount > model.contextWindow() - 300) { return new BinaryThreatAnalysis( "Ignoring file: estimated prompt token count (" + tokenCount + ") is too high.", BinaryThreatRisk.LOW); @@ -152,8 +156,7 @@ private BinaryThreatAnalysis analyzeThreat( logger.debug("estimated prompt token count: {}", tokenCount); } - return getLLMResponse( - "gpt-3.5-turbo-0613", 0.2D, systemMessage, userMessage, BinaryThreatAnalysis.class); + return getLLMResponse(model.id(), 0.0D, systemMessage, userMessage, BinaryThreatAnalysis.class); } private BinaryThreatAnalysisAndFix fixThreat( @@ -161,7 +164,7 @@ private BinaryThreatAnalysisAndFix fixThreat( final CodemodInvocationContext context, final List results) { return getLLMResponse( - "gpt-4-0613", + model.id(), 0D, getSystemMessage(context, results), getFixUserMessage(file), @@ -194,7 +197,7 @@ private T getLLMResponse( .build(); ChatCompletionResult result = openAI.createChatCompletion(request); - 
logger.debug(result.getUsage().toString()); + logger.debug("{}", result.getUsage()); ChatMessage response = result.getChoices().get(0).getMessage(); return functionExecutor.execute(response.getFunctionCall()); @@ -242,18 +245,42 @@ private ChatMessage getFixUserMessage(final FileDescription file) { private static final String FIX_USER_MESSAGE_TEMPLATE = """ - A file with line numbers is provided below. Analyze it. If the risk is HIGH, use these rules \ - to make the MINIMUM number of changes necessary to reduce the file's risk to LOW: + A file with line numbers is provided below. Analyze it. If the risk is HIGH, use these rules to make the MINIMUM number of changes necessary to reduce the file's risk to LOW: - Each change MUST be syntactically correct. - - DO NOT change the file's formatting or comments. %s - Create a diff patch for the changed file, using the unified format with a header. Include \ - the diff patch and a summary of the changes with your threat analysis. - - Save your threat analysis. + Any code changes to reduce the file's risk to LOW must be stored in a diff patch format. Follow these instructions when creating the patch: + - Your output must be in the form a unified diff patch that will be applied by your coworkers. + - The output must be similar to the output of `diff -U0`. Do not include line number ranges. + - Start each hunk of changes with a `@@ ... @@` line. + - Each change in a file should be a separate hunk in the diff. + - It is very important for the change to contain only what is minimally required to fix the problem. + - Remember that whitespace and indentation changes can be important. Preserve the original formatting and indentation. Do not replace tabs with spaces or vice versa. If the original code uses tabs, use tabs in the patch. Encode tabs using a tab literal (\\\\t). If the original code uses spaces, use spaces in the patch. Do not add spaces where none were present in the original code. 
**THIS IS ESPECIALLY IMPORTANT AT THE BEGINNING OF DIFF LINES.** + - The unified diff must be accurate and complete. + - The unified diff will be applied to the source code by your coworkers. + + Here's an example of a unified diff: + ```diff + --- a/file.txt + +++ b/file.txt + @@ ... @@ + for (var i = 0; i < array.length; i++) { + This line is unchanged. + - This is the original line + + This is the replacement line + } + Here is another unchanged line. + @@ ... @@ + -This line has been removed but not replaced. + This line is unchanged. + ``` + + Now save your threat analysis. --- %s %s """; + + private static final Logger logger = + LoggerFactory.getLogger(SarifToLLMForBinaryVerificationAndFixingCodemod.class); } diff --git a/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/SarifToLLMForMultiOutcomeCodemod.java b/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/SarifToLLMForMultiOutcomeCodemod.java new file mode 100644 index 000000000..4abdad288 --- /dev/null +++ b/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/SarifToLLMForMultiOutcomeCodemod.java @@ -0,0 +1,484 @@ +package io.codemodder.plugins.llm; + +import com.contrastsecurity.sarif.Region; +import com.contrastsecurity.sarif.Result; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.fasterxml.jackson.annotation.JsonPropertyDescription; +import com.github.difflib.DiffUtils; +import com.github.difflib.patch.Patch; +import com.theokanning.openai.completion.chat.*; +import com.theokanning.openai.completion.chat.ChatCompletionRequest.ChatCompletionRequestFunctionCall; +import com.theokanning.openai.service.FunctionExecutor; +import io.codemodder.*; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.util.*; +import java.util.function.Function; +import java.util.stream.Collectors; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * An extension of {@link 
SarifPluginRawFileChanger} that uses large language models (LLMs) to + * perform some analysis and categorize what's found to drive different potential code changes. + * + *

The inspiration for this type was the "remediate something found by tool X" use case. For + * example, if a tool cites a vulnerability on a given line, we may want to take any of the + * following actions: + * + *

    + *
  • Fix the identified issue by doing A + *
  • Fix the identified issue by doing B + *
  • Add a suppression comment to the given line since it's likely a false positive + *
  • Refactor the code so it doesn't trip the rule anymore, without actually "fixing it" + *
  • Do nothing, since the LLM can't determine which case the code is + *
+ * + *

To accomplish that, we need the analysis to "bucket" the code into one of the above + * categories. + */ +public abstract class SarifToLLMForMultiOutcomeCodemod extends SarifPluginRawFileChanger { + + private static final Logger logger = + LoggerFactory.getLogger(SarifToLLMForMultiOutcomeCodemod.class); + private final OpenAIService openAI; + private final List remediationOutcomes; + private final Model categorizationModel; + private final Model codeChangingModel; + + protected SarifToLLMForMultiOutcomeCodemod( + final RuleSarif sarif, + final OpenAIService openAI, + final List remediationOutcomes) { + this(sarif, openAI, remediationOutcomes, StandardModel.GPT_4O, StandardModel.GPT_4); + } + + protected SarifToLLMForMultiOutcomeCodemod( + final RuleSarif sarif, + final OpenAIService openAI, + final List remediationOutcomes, + final Model categorizationModel, + final Model codeChangingModel) { + super(sarif); + this.openAI = Objects.requireNonNull(openAI); + this.remediationOutcomes = Objects.requireNonNull(remediationOutcomes); + if (remediationOutcomes.size() < 2) { + throw new IllegalArgumentException("must have 2+ remediation outcome"); + } + this.categorizationModel = Objects.requireNonNull(categorizationModel); + this.codeChangingModel = Objects.requireNonNull(codeChangingModel); + } + + @Override + public CodemodFileScanningResult onFileFound( + final CodemodInvocationContext context, final List results) { + logger.info("processing: {}", context.path()); + + List changes = new ArrayList<>(); + for (Result result : results) { + Optional change = processResult(context, result); + change.ifPresent(changes::add); + } + return CodemodFileScanningResult.withOnlyChanges(List.copyOf(changes)); + } + + private Optional processResult( + final CodemodInvocationContext context, final Result result) { + // short-circuit if the code is too long + if (estimatedToExceedContextWindow(context)) { + logger.debug("code too long: {}", context.path()); + return Optional.empty(); + 
} + try { + FileDescription file = FileDescription.from(context.path()); + + final CategorizeResponse analysis = categorize(file, result); + String outcomeKey = analysis.getOutcomeKey(); + logger.debug("outcomeKey: {}", outcomeKey); + logger.debug("analysis: {}", analysis.getAnalysis()); + if (outcomeKey == null || outcomeKey.isBlank()) { + logger.debug("unable to determine outcome"); + return Optional.empty(); + } + Optional outcome = + remediationOutcomes.stream() + .filter(oc -> oc.key().equals(analysis.outcomeKey)) + .findFirst(); + if (outcome.isEmpty()) { + logger.debug("unable to find outcome for key: {}", analysis.outcomeKey); + return Optional.empty(); + } + LLMRemediationOutcome matchedOutcome = outcome.get(); + logger.debug("outcomeKey: {}", matchedOutcome.key()); + logger.debug("description: {}", matchedOutcome.description()); + if (!matchedOutcome.shouldApplyCodeChanges()) { + logger.debug("Matched outcome suggests there should be no code changes"); + return Optional.empty(); + } + + CodeChangeResponse response = changeCode(file, result); + logger.debug("outcome: {}", response.outcomeKey); + logger.debug("analysis: {}", response.codeChange); + + // If our second look determined that there are no outcomes associated with code changes, we + // should quit + if (response.outcomeKey == null || outcomeKey.isEmpty()) { + logger.debug("No outcomes detected"); + return Optional.empty(); + } + + List codeChangingOutcomeKeys = + remediationOutcomes.stream() + .filter(LLMRemediationOutcome::shouldApplyCodeChanges) + .map(LLMRemediationOutcome::key) + .toList(); + + boolean anyRequireCodeChanges = codeChangingOutcomeKeys.contains(response.outcomeKey); + if (!anyRequireCodeChanges) { + logger.debug("On second analysis, outcomes require no code changes"); + return Optional.empty(); + } + + String codeChange = response.codeChange; + // If the LLM was unable to fix the threat, don't change the file. 
+ if (codeChange == null || codeChange.length() == 0) { + logger.info("unable to fix because diff not present: {}", context.path()); + return Optional.empty(); + } + + // Apply the fix. + List fixedLines = LLMDiffs.applyDiff(file.getLines(), codeChange); + + // Ensure the end result isn't wonky. + Patch patch = DiffUtils.diff(file.getLines(), fixedLines); + if (patch.getDeltas().size() == 0) { + logger.error("empty patch: {}", patch); + return Optional.empty(); + } + + try { + // Replace the file with the fixed version. + String fixedFile = String.join(file.getLineSeparator(), fixedLines); + Files.writeString(context.path(), fixedFile, file.getCharset()); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + + // Report all the changes at their respective line number + + return Optional.of(createCodemodChange(result, response.line, response.fixDescription)); + } catch (Exception e) { + logger.error("failed to process: {}", context.path(), e); + throw e; + } + } + + /** + * Estimates if the code is too long to analyze within the context windows of the given models. + * This is only an estimate: the actual token count may be higher but won't be lower. 
+ * + * @param context the current codemod invocation context + * @return true when the prompts are estimated to exceed the context window for the models used in + * this codemod + */ + private boolean estimatedToExceedContextWindow(final CodemodInvocationContext context) { + // in both the threat analysis and code fix cases, the estimated user message size is dominated + // by the code snippet, so use the code snippets as the floor + final var estimatedUserMessage = + new ChatMessage(ChatMessageRole.USER.value(), context.contents()); + for (final var model : List.of(categorizationModel, codeChangingModel)) { + int tokenCount = model.tokens(List.of(getSystemMessage(), estimatedUserMessage)); + // estimated token count doesn't include the function (~100 tokens) or the reply + // (~200 tokens) so add those estimates before checking against window size + tokenCount += 300; + if (tokenCount > model.contextWindow()) { + return true; + } + } + return false; + } + + /** + * Create a {@link CodemodChange} from the given code change data. + * + * @param line the line number of the change + * @param fixDescription the description of the change + */ + protected CodemodChange createCodemodChange( + final Result result, final int line, final String fixDescription) { + return CodemodChange.from(line, fixDescription); + } + + /** + * Instructs the LLM on how to assess the risk of the threat. + * + * @return The prompt. 
+ */ + protected abstract String getThreatPrompt(); + + private CategorizeResponse categorize(final FileDescription file, final Result result) { + ChatMessage systemMessage = getSystemMessage(); + ChatMessage userMessage = getCategorizationUserMessage(file, result); + return getCategorizationResponse(systemMessage, userMessage); + } + + private CodeChangeResponse changeCode(final FileDescription file, final Result result) { + return getCodeChangeResponse(getSystemMessage(), getChangeCodeMessage(file, result)); + } + + private CategorizeResponse getCategorizationResponse( + final ChatMessage systemMessage, final ChatMessage userMessage) { + // Create a function to get the LLM to return a structured response. + ChatFunction function = + ChatFunction.builder() + .name("save_categorization_analysis") + .description("Saves a categorization analysis.") + .executor( + CategorizeResponse.class, + c -> c) // Return the `responseClass` instance when executed. + .build(); + + FunctionExecutor functionExecutor = new FunctionExecutor(Collections.singletonList(function)); + + ChatCompletionRequest request = + ChatCompletionRequest.builder() + .model(categorizationModel.id()) + .messages(List.of(systemMessage, userMessage)) + .functions(functionExecutor.getFunctions()) + .functionCall(ChatCompletionRequestFunctionCall.of(function.getName())) + .temperature(0.0) + .build(); + + ChatCompletionResult result = openAI.createChatCompletion(request); + logger.debug(result.getUsage().toString()); + + ChatMessage response = result.getChoices().get(0).getMessage(); + return functionExecutor.execute(response.getFunctionCall()); + } + + private CodeChangeResponse getCodeChangeResponse( + final ChatMessage systemMessage, final ChatMessage userMessage) { + // Create a function to get the LLM to return a structured response. 
+ ChatFunction function = + ChatFunction.builder() + .name("save_categorization_analysis_and_code_change") + .description("Saves a categorization, analysis and code change.") + .executor(CodeChangeResponse.class, c -> c) + .build(); + + FunctionExecutor functionExecutor = new FunctionExecutor(Collections.singletonList(function)); + + ChatCompletionRequest request = + ChatCompletionRequest.builder() + .model(codeChangingModel.id()) + .messages(List.of(systemMessage, userMessage)) + .functions(functionExecutor.getFunctions()) + .functionCall(ChatCompletionRequestFunctionCall.of(function.getName())) + .topP(0.1) + .temperature(0.0) + .build(); + + ChatCompletionResult result = openAI.createChatCompletion(request); + logger.debug(result.getUsage().toString()); + + ChatMessage response = result.getChoices().get(0).getMessage(); + return functionExecutor.execute(response.getFunctionCall()); + } + + private ChatMessage getSystemMessage() { + return new ChatMessage( + ChatMessageRole.SYSTEM.value(), + SYSTEM_MESSAGE_TEMPLATE.formatted(getThreatPrompt().strip()).strip()); + } + + /** Analyze a single SARIF result and get feedback. */ + private ChatMessage getCategorizationUserMessage( + final FileDescription file, final Result result) { + Region region = result.getLocations().get(0).getPhysicalLocation().getRegion(); + int line = region.getStartLine(); + Integer column = region.getStartColumn(); + + String outcomeDescriptions = formatOutcomeDescriptions(false); + + return new ChatMessage( + ChatMessageRole.SYSTEM.value(), + CATEGORIZE_CODE_USER_MESSAGE_TEMPLATE + .formatted( + String.valueOf(line), + column != null ? String.valueOf(column) : "(unknown)", + outcomeDescriptions, + file.getFileName(), + file.formatLinesWithLineNumbers()) + .strip()); + } + + /** + * Format the outcome descriptions for sending to the LLM. Should look something like this: + * + *

+   * ===
+   * Outcome: 'assignment_is_redundant':
+   * Description: The variable is assigned and re-assigned to the same value. This is redundant and should be removed.
+   * Code Changes Required: YES
+   * Code Change Directions: Remove the initial assignment.
+   * ===
+   * Outcome: 'assignment_can_be_streamlined':
+   * Description: The variable is created and then assigned in separate adjacent statements.
+   * Code Changes Required: YES
+   * Code Change Directions: Combine the two statements together.
+   * ===
+   * ...
+   * 
+ */ + private String formatOutcomeDescriptions(boolean includeFixes) { + String withFixTemplate = + """ + ============ + Outcome: %s + Description: %s + Code Changes Required: YES + Code Change Directions For Outcome: %s + """; + String withoutFixTemplate = + """ + ============ + Outcome: %s + Description: %s + Code Changes Required: NO + """; + + Function withFixProvider = + (outcome) -> withFixTemplate.formatted(outcome.key(), outcome.description(), outcome.fix()); + Function withoutFixProvider = + (outcome) -> withoutFixTemplate.formatted(outcome.key(), outcome.description()); + return remediationOutcomes.stream() + .map(oc -> includeFixes ? withFixProvider.apply(oc) : withoutFixProvider.apply(oc)) + .collect(Collectors.joining("\n")) + + "\n============"; + } + + /** + * Analyze a single SARIF result, and get the changed file back as well if it warrants change. + * + * @param file the file being analyzed + * @param result the result to analyze + * @return the message to send to the LLM + */ + private ChatMessage getChangeCodeMessage(final FileDescription file, final Result result) { + + Region region = result.getLocations().get(0).getPhysicalLocation().getRegion(); + String regionStr = " Line " + region.getStartLine() + ", column " + region.getStartColumn(); + + String outcomeDescriptions = formatOutcomeDescriptions(true); + return new ChatMessage( + ChatMessageRole.USER.value(), + CHANGE_CODE_USER_MESSAGE_TEMPLATE + .formatted( + regionStr, + outcomeDescriptions, + file.getFileName(), + file.formatLinesWithLineNumbers()) + .strip()); + } + + private static final String SYSTEM_MESSAGE_TEMPLATE = + """ + You are a security analyst bot. You are helping analyze code to assess its risk to a \ + specific security threat. Your code change recommendations are safe and accurate. + %s + """; + + private static final String CATEGORIZE_CODE_USER_MESSAGE_TEMPLATE = + """ + Analyze ONLY line %s, column %s, and discern which "outcome" best describes the code. 
You should save your categorization analysis. You MUST ignore any other file contents, even if they look like they have issues. + Here are the possible outcomes: + %s + --- %s + %s + """; + + static class CategorizeResponse { + @JsonPropertyDescription("A detailed analysis of how the analysis arrived at the outcome") + @JsonProperty(required = true) + private String analysis; + + @JsonPropertyDescription( + "The category of the analysis, or empty if the analysis could not categorized") + @JsonProperty(required = true) + private String outcomeKey; + + @SuppressWarnings("unused") // needed by Jackson + public CategorizeResponse() {} + + private CategorizeResponse(final String analysis, final String outcomeKey) { + this.analysis = analysis; + this.outcomeKey = outcomeKey; + } + + public String getAnalysis() { + return analysis; + } + + public String getOutcomeKey() { + return outcomeKey; + } + } + + private static final String CHANGE_CODE_USER_MESSAGE_TEMPLATE = + """ + The tool has cited the following location for you to analyze: + %s + Decide which "outcome" you want to place it in. Then, if that outcome requires code changes, make the changes as described in the Code Change Directions and save them. Here are the possible outcomes: + %s + Pick which outcome best describes the code. If you are making code changes, you MUST make the MINIMUM number of changes necessary to fix the issue. + - Each change MUST be syntactically correct. + - DO NOT change the file's formatting or comments. + - Create a diff patch for the changed file if and only if any of the outcomes require code changes. + - The patch must use the unified format with a header. Include the diff patch and a summary of the changes with your analysis. + If you the outcome says you should suppress a Semgrep finding in the code, insert a comment above it and put `// nosemgrep: ` + Save your categorization and code change analysis when you're done. 
+ --- %s + %s + """; + + static final class CodeChangeResponse { + @JsonPropertyDescription( + "The code change a diff patch in unified format. Required if any of the outcome keys indicate a change.") + private String codeChange; + + @SuppressWarnings("MismatchedQueryAndUpdateOfCollection") + @JsonPropertyDescription("The line in the file to which this analysis is related") + private int line; + + @JsonPropertyDescription("The column to which this analysis is related") + private int column; + + @JsonPropertyDescription("The outcome key associated with this particular result location") + private String outcomeKey; + + @JsonPropertyDescription( + "A short description of the code change. Required only if the file needs a change.") + private String fixDescription; + + public String getFixDescription() { + return fixDescription; + } + + public String getOutcomeKey() { + return outcomeKey; + } + + public int getLine() { + return line; + } + + public int getColumn() { + return column; + } + + public String getCodeChange() { + return codeChange; + } + } +} diff --git a/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/StandardModel.java b/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/StandardModel.java new file mode 100644 index 000000000..770b3d116 --- /dev/null +++ b/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/StandardModel.java @@ -0,0 +1,49 @@ +package io.codemodder.plugins.llm; + +import com.knuddels.jtokkit.api.EncodingType; +import com.theokanning.openai.completion.chat.ChatMessage; +import java.util.List; + +/** Well-known GPT models used in Codemod development. 
*/ +public enum StandardModel implements Model { + GPT_3_5_TURBO("gpt-3.5-turbo-0125", 16_385) { + @Override + public int tokens(final List messages) { + return Tokens.countTokens(messages, 3, EncodingType.CL100K_BASE); + } + }, + GPT_4("gpt-4-0613", 8_192) { + @Override + public int tokens(final List messages) { + return Tokens.countTokens(messages, 3, EncodingType.CL100K_BASE); + } + }, + GPT_4O("gpt-4o-2024-05-13", 128_000) { + /** + * This is wrong - we copy / pasted from GPT 3.5 while we await GPT-4o token counting support from upstream utility. + */ + @Override + public int tokens(final List messages) { + return Tokens.countTokens(messages, 3, EncodingType.CL100K_BASE); + } + }; + + private final String id; + private final int contextWindow; + + StandardModel(final String id, final int contextWindow) { + this.id = id; + this.contextWindow = contextWindow; + } + + @Override + public String id() { + return id; + } + + @Override + public int contextWindow() { + return contextWindow; + } +} diff --git a/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/Tokens.java b/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/Tokens.java index 12910acab..cfdf50923 100644 --- a/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/Tokens.java +++ b/plugins/codemodder-plugin-llm/src/main/java/io/codemodder/plugins/llm/Tokens.java @@ -13,8 +13,7 @@ public final class Tokens { private Tokens() {} /** - * Estimates the number of tokens the messages will consume when passed to the {@code - * gpt-3.5-turbo-0613} or {@code gpt-4-0613} models. + * Estimates the number of tokens the messages will consume. * *

This does not yet support estimating the number of tokens the functions will consume, since * the feature is released. * * @param messages The messages. + * @param tokensPerMessage The number of tokens consumed per message by the given model. + * @param encodingType The encoding type used by the model. * @return The number of tokens. * @see How * to count tokens with tiktoken */ - public static int countTokens(final List messages) { + public static int countTokens( + final List messages, + final int tokensPerMessage, + final EncodingType encodingType) { EncodingRegistry registry = Encodings.newDefaultEncodingRegistry(); - Encoding encoding = registry.getEncoding(EncodingType.CL100K_BASE); + Encoding encoding = registry.getEncoding(encodingType); int count = 0; for (ChatMessage message : messages) { - count += 3; // Both gpt-3.5-turbo-0613 and gpt-4-0613 consume 3 tokens per message. + count += tokensPerMessage; count += encoding.countTokens(message.getContent()); count += encoding.countTokens(message.getRole()); } - count += 3; // Every reply is primed with <|start|>assistant<|message|>. + count += tokensPerMessage; // Every reply is primed with <|start|>assistant<|message|>. 
return count; } diff --git a/plugins/codemodder-plugin-llm/src/testFixtures/java/io/codemodder/plugins/llm/test/LLMVerifyingCodemodTestMixin.java b/plugins/codemodder-plugin-llm/src/testFixtures/java/io/codemodder/plugins/llm/test/LLMVerifyingCodemodTestMixin.java index d6b6d4ad7..4c62459fc 100644 --- a/plugins/codemodder-plugin-llm/src/testFixtures/java/io/codemodder/plugins/llm/test/LLMVerifyingCodemodTestMixin.java +++ b/plugins/codemodder-plugin-llm/src/testFixtures/java/io/codemodder/plugins/llm/test/LLMVerifyingCodemodTestMixin.java @@ -14,7 +14,9 @@ import com.theokanning.openai.completion.chat.ChatMessageRole; import com.theokanning.openai.service.FunctionExecutor; import io.codemodder.EncodingDetector; +import io.codemodder.plugins.llm.Model; import io.codemodder.plugins.llm.OpenAIService; +import io.codemodder.plugins.llm.StandardModel; import io.codemodder.testutils.CodemodTestMixin; import java.io.IOException; import java.nio.charset.Charset; @@ -32,6 +34,13 @@ public interface LLMVerifyingCodemodTestMixin extends CodemodTestMixin { */ String getRequirementsPrompt(); + /** + * @return GPT model to use for the test harness to verify the codemod's changes + */ + default Model model() { + return StandardModel.GPT_4O; + } + @Override default void verifyTransformedCode(final Path before, final Path expected, final Path after) throws IOException { @@ -63,7 +72,7 @@ private Assessment assessChanges( ChatCompletionRequest request = ChatCompletionRequest.builder() - .model("gpt-4o-2024-05-13") + .model(model().id()) .messages( List.of( new ChatMessage(