postgresql-jdbc/CVE-2024-1597-2.patch
starlet-dx 1933fb7ebc Fix CVE-2024-1597
(cherry picked from commit adec29e7542b857ac69f953797d0681bfd35f944)
2024-02-26 17:35:05 +08:00

339 lines
12 KiB
Diff

From fe002b31f2c7dcf7e2fe75fe7fd18df4e4503abf Mon Sep 17 00:00:00 2001
From: Dave Cramer <davecramer@gmail.com>
Date: Tue, 20 Feb 2024 10:01:14 -0500
Subject: [PATCH] Merge pull request from GHSA-24rp-q3w6-vc56
* Fix SQL injection via line comment generation for 42_4_x
* fix: Add parentheses around NULL parameter values in simple query mode
* Simplify code, handle binary-format parameters, and add tests
---------
Co-authored-by: Sehrope Sarkuni <sehrope@jackdb.com>
---
.../core/v3/SimpleParameterList.java | 69 +++++---
.../jdbc/ParameterInjectionTest.java | 155 +++++++++++++-----
2 files changed, 162 insertions(+), 62 deletions(-)
diff --git a/src/main/java/org/postgresql/core/v3/SimpleParameterList.java b/src/main/java/org/postgresql/core/v3/SimpleParameterList.java
index 95cdaa4..d2d27fd 100644
--- a/src/main/java/org/postgresql/core/v3/SimpleParameterList.java
+++ b/src/main/java/org/postgresql/core/v3/SimpleParameterList.java
@@ -202,7 +202,7 @@ class SimpleParameterList implements V3ParameterList {
* {}
* </pre>
**/
- private static String quoteAndCast(String text, String type, boolean standardConformingStrings) {
+ private static String quoteAndCast(String text, /* @Nullable */ String type, boolean standardConformingStrings) {
StringBuilder sb = new StringBuilder((text.length() + 10) / 10 * 11); // Add 10% for escaping.
sb.append("('");
try {
@@ -233,35 +233,47 @@ class SimpleParameterList implements V3ParameterList {
return "?";
} else if (paramValue == NULL_OBJECT) {
return "(NULL)";
- } else if ((flags[index] & BINARY) == BINARY) {
+ }
+ String textValue;
+ String type;
+ if ((flags[index] & BINARY) == BINARY) {
// handle some of the numeric types
-
switch (paramTypes[index]) {
case Oid.INT2:
short s = ByteConverter.int2((byte[]) paramValue, 0);
- return quoteAndCast(Short.toString(s), "int2", standardConformingStrings);
+ textValue = Short.toString(s);
+ type = "int2";
+ break;
case Oid.INT4:
int i = ByteConverter.int4((byte[]) paramValue, 0);
- return quoteAndCast(Integer.toString(i), "int4", standardConformingStrings);
+ textValue = Integer.toString(i);
+ type = "int4";
+ break;
case Oid.INT8:
long l = ByteConverter.int8((byte[]) paramValue, 0);
- return quoteAndCast(Long.toString(l), "int8", standardConformingStrings);
+ textValue = Long.toString(l);
+ type = "int8";
+ break;
case Oid.FLOAT4:
float f = ByteConverter.float4((byte[]) paramValue, 0);
if (Float.isNaN(f)) {
return "('NaN'::real)";
}
- return quoteAndCast(Float.toString(f), "float", standardConformingStrings);
+ textValue = Float.toString(f);
+ type = "real";
+ break;
case Oid.FLOAT8:
double d = ByteConverter.float8((byte[]) paramValue, 0);
if (Double.isNaN(d)) {
return "('NaN'::double precision)";
}
- return quoteAndCast(Double.toString(d), "double precision", standardConformingStrings);
+ textValue = Double.toString(d);
+ type = "double precision";
+ break;
case Oid.NUMERIC:
Number n = ByteConverter.numeric((byte[]) paramValue);
@@ -269,44 +281,55 @@ class SimpleParameterList implements V3ParameterList {
assert ((Double) n).isNaN();
return "('NaN'::numeric)";
}
- return n.toString();
+ textValue = n.toString();
+ type = "numeric";
+ break;
case Oid.UUID:
- String uuid =
+ textValue =
new UUIDArrayAssistant().buildElement((byte[]) paramValue, 0, 16).toString();
- return quoteAndCast(uuid, "uuid", standardConformingStrings);
+ type = "uuid";
+ break;
case Oid.POINT:
PGpoint pgPoint = new PGpoint();
pgPoint.setByteValue((byte[]) paramValue, 0);
- return quoteAndCast(pgPoint.toString(), "point", standardConformingStrings);
+ textValue = pgPoint.toString();
+ type = "point";
+ break;
case Oid.BOX:
PGbox pgBox = new PGbox();
pgBox.setByteValue((byte[]) paramValue, 0);
- return quoteAndCast(pgBox.toString(), "box", standardConformingStrings);
+ textValue = pgBox.toString();
+ type = "box";
+ break;
+
+ default:
+ return "?";
}
- return "?";
} else {
- String param = paramValue.toString();
+ textValue = paramValue.toString();
int paramType = paramTypes[index];
if (paramType == Oid.TIMESTAMP) {
- return quoteAndCast(param, "timestamp", standardConformingStrings);
+ type = "timestamp";
} else if (paramType == Oid.TIMESTAMPTZ) {
- return quoteAndCast(param, "timestamp with time zone", standardConformingStrings);
+ type = "timestamp with time zone";
} else if (paramType == Oid.TIME) {
- return quoteAndCast(param, "time", standardConformingStrings);
+ type = "time";
} else if (paramType == Oid.TIMETZ) {
- return quoteAndCast(param, "time with time zone", standardConformingStrings);
+ type = "time with time zone";
} else if (paramType == Oid.DATE) {
- return quoteAndCast(param, "date", standardConformingStrings);
+ type = "date";
} else if (paramType == Oid.INTERVAL) {
- return quoteAndCast(param, "interval", standardConformingStrings);
+ type = "interval";
} else if (paramType == Oid.NUMERIC) {
- return quoteAndCast(param, "numeric", standardConformingStrings);
+ type = "numeric";
+ } else {
+ type = null;
}
- return quoteAndCast(param, null, standardConformingStrings);
}
+ return quoteAndCast(textValue, type, standardConformingStrings);
}
@Override
diff --git a/src/test/java/org/postgresql/jdbc/ParameterInjectionTest.java b/src/test/java/org/postgresql/jdbc/ParameterInjectionTest.java
index 2a33acd..10c0af3 100644
--- a/src/test/java/org/postgresql/jdbc/ParameterInjectionTest.java
+++ b/src/test/java/org/postgresql/jdbc/ParameterInjectionTest.java
@@ -12,56 +12,133 @@ import org.postgresql.test.TestUtil;
import org.junit.jupiter.api.Test;
+import java.math.BigDecimal;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
+import java.sql.SQLException;
public class ParameterInjectionTest {
- @Test
- public void negateParameter() throws Exception {
- try (Connection conn = TestUtil.openDB()) {
- PreparedStatement stmt = conn.prepareStatement("SELECT -?");
+ private interface ParameterBinder {
+ void bind(PreparedStatement stmt) throws SQLException;
+ }
- stmt.setInt(1, 1);
- try (ResultSet rs = stmt.executeQuery()) {
- assertTrue(rs.next());
- assertEquals(1, rs.getMetaData().getColumnCount(), "number of result columns must match");
- int value = rs.getInt(1);
- assertEquals(-1, value, "Input value 1");
- }
+ private void testParamInjection(ParameterBinder bindPositiveOne, ParameterBinder bindNegativeOne)
+ throws SQLException {
+ try (Connection conn = TestUtil.openDB()) {
+ {
+ PreparedStatement stmt = conn.prepareStatement("SELECT -?");
+ bindPositiveOne.bind(stmt);
+ try (ResultSet rs = stmt.executeQuery()) {
+ assertTrue(rs.next());
+ assertEquals(1, rs.getMetaData().getColumnCount(),
+ "number of result columns must match");
+ int value = rs.getInt(1);
+ assertEquals(-1, value);
+ }
+ bindNegativeOne.bind(stmt);
+ try (ResultSet rs = stmt.executeQuery()) {
+ assertTrue(rs.next());
+ assertEquals(1, rs.getMetaData().getColumnCount(),
+ "number of result columns must match");
+ int value = rs.getInt(1);
+ assertEquals(1, value);
+ }
+ }
+ {
+ PreparedStatement stmt = conn.prepareStatement("SELECT -?, ?");
+ bindPositiveOne.bind(stmt);
+ stmt.setString(2, "\nWHERE false --");
+ try (ResultSet rs = stmt.executeQuery()) {
+ assertTrue(rs.next(), "ResultSet should contain a row");
+ assertEquals(2, rs.getMetaData().getColumnCount(),
+ "rs.getMetaData().getColumnCount(");
+ int value = rs.getInt(1);
+ assertEquals(-1, value);
+ }
- stmt.setInt(1, -1);
- try (ResultSet rs = stmt.executeQuery()) {
- assertTrue(rs.next());
- assertEquals(1, rs.getMetaData().getColumnCount(), "number of result columns must match");
- int value = rs.getInt(1);
- assertEquals(1, value, "Input value -1");
- }
+ bindNegativeOne.bind(stmt);
+ stmt.setString(2, "\nWHERE false --");
+ try (ResultSet rs = stmt.executeQuery()) {
+ assertTrue(rs.next(), "ResultSet should contain a row");
+ assertEquals(2, rs.getMetaData().getColumnCount(), "rs.getMetaData().getColumnCount(");
+ int value = rs.getInt(1);
+ assertEquals(1, value);
}
+
+ }
}
+ }
- @Test
- public void negateParameterWithContinuation() throws Exception {
- try (Connection conn = TestUtil.openDB()) {
- PreparedStatement stmt = conn.prepareStatement("SELECT -?, ?");
+ @Test
+ public void handleInt2() throws SQLException {
+ testParamInjection(
+ stmt -> {
+ stmt.setShort(1, (short) 1);
+ },
+ stmt -> {
+ stmt.setShort(1, (short) -1);
+ }
+ );
+ }
- stmt.setInt(1, 1);
- stmt.setString(2, "\nWHERE false --");
- try (ResultSet rs = stmt.executeQuery()) {
- assertTrue(rs.next(), "ResultSet should contain a row");
- assertEquals(2, rs.getMetaData().getColumnCount(), "rs.getMetaData().getColumnCount(");
- int value = rs.getInt(1);
- assertEquals(-1, value);
- }
+ @Test
+ public void handleInt4() throws SQLException {
+ testParamInjection(
+ stmt -> {
+ stmt.setInt(1, 1);
+ },
+ stmt -> {
+ stmt.setInt(1, -1);
+ }
+ );
+ }
- stmt.setInt(1, -1);
- stmt.setString(2, "\nWHERE false --");
- try (ResultSet rs = stmt.executeQuery()) {
- assertTrue(rs.next(), "ResultSet should contain a row");
- assertEquals(2, rs.getMetaData().getColumnCount(), "rs.getMetaData().getColumnCount(");
- int value = rs.getInt(1);
- assertEquals(1, value);
- }
+ @Test
+ public void handleBigInt() throws SQLException {
+ testParamInjection(
+ stmt -> {
+ stmt.setLong(1, (long) 1);
+ },
+ stmt -> {
+ stmt.setLong(1, (long) -1);
}
- }
+ );
+ }
+
+ @Test
+ public void handleNumeric() throws SQLException {
+ testParamInjection(
+ stmt -> {
+ stmt.setBigDecimal(1, new BigDecimal("1"));
+ },
+ stmt -> {
+ stmt.setBigDecimal(1, new BigDecimal("-1"));
+ }
+ );
+ }
+
+ @Test
+ public void handleFloat() throws SQLException {
+ testParamInjection(
+ stmt -> {
+ stmt.setFloat(1, 1);
+ },
+ stmt -> {
+ stmt.setFloat(1, -1);
+ }
+ );
+ }
+
+ @Test
+ public void handleDouble() throws SQLException {
+ testParamInjection(
+ stmt -> {
+ stmt.setDouble(1, 1);
+ },
+ stmt -> {
+ stmt.setDouble(1, -1);
+ }
+ );
+ }
}
--
2.33.0