From e479be16967ae20af1bd57187738d5243052bbee Mon Sep 17 00:00:00 2001 From: Dave Cramer Date: Mon, 19 Feb 2024 08:20:59 -0500 Subject: [PATCH] Merge pull request from GHSA-24rp-q3w6-vc56 * SQL Injection via line comment generation for 42_4_x * fix: Add parentheses around NULL parameter values in simple query mode --------- Co-authored-by: Sehrope Sarkuni --- .../core/v3/SimpleParameterList.java | 112 +++++++++++------- .../core/v3/V3ParameterListTests.java | 6 +- .../jdbc/ParameterInjectionTest.java | 67 +++++++++++ 3 files changed, 142 insertions(+), 43 deletions(-) create mode 100644 src/test/java/org/postgresql/jdbc/ParameterInjectionTest.java diff --git a/src/main/java/org/postgresql/core/v3/SimpleParameterList.java b/src/main/java/org/postgresql/core/v3/SimpleParameterList.java index 969880b..b996e0b 100644 --- a/src/main/java/org/postgresql/core/v3/SimpleParameterList.java +++ b/src/main/java/org/postgresql/core/v3/SimpleParameterList.java @@ -172,6 +172,58 @@ class SimpleParameterList implements V3ParameterList { } bind(index, NULL_OBJECT, oid, binaryTransfer); } + + /** + *

Escapes a given text value as a literal, wraps it in single quotes, casts it + * to the given data type, and finally wraps the whole thing in parentheses.

+ * + *

For example, "123" and "int4" become "('123'::int4)"

+ * + *

The additional parentheses are added to ensure that the surrounding text of where the + * parameter value is entered does not modify the interpretation of the value.

+ * + *

For example, if our input SQL is: SELECT ?b

+ * + *

Using a parameter value of '{}' and type of json we'd get:

+ * + *
+   * test=# SELECT ('{}'::json)b;
+   *  b
+   * ----
+   *  {}
+   * 
+ * + *

But without the parentheses the trailing b is parsed as part of the type name (json becomes jsonb), so the result changes:

+ * + *
+   * test=# SELECT '{}'::jsonb;
+   * jsonb
+   * -------
+   * {}
+   * 
+ **/ + private static String quoteAndCast(String text, String type, boolean standardConformingStrings) { + StringBuilder sb = new StringBuilder((text.length() + 10) / 10 * 11); // Add 10% for escaping. + sb.append("('"); + try { + Utils.escapeLiteral(sb, text, standardConformingStrings); + } catch (SQLException e) { + // This should only happen if we have an embedded null + // and there's not much we can do if we do hit one. + // + // To force a server side failure, we deliberately include + // a zero byte character in the literal to force the server + // to reject the command. + sb.append('\u0000'); + } + sb.append("'"); + if (type != null) { + sb.append("::"); + sb.append(type); + } + sb.append(")"); + return sb.toString(); + } @Override public String toString(/* @Positive */ int index, boolean standardConformingStrings) { @@ -180,100 +232,80 @@ class SimpleParameterList implements V3ParameterList { if (paramValue == null) { return "?"; } else if (paramValue == NULL_OBJECT) { - return "NULL"; + return "(NULL)"; } else if ((flags[index] & BINARY) == BINARY) { // handle some of the numeric types switch (paramTypes[index]) { case Oid.INT2: short s = ByteConverter.int2((byte[]) paramValue, 0); - return Short.toString(s); + return quoteAndCast(Short.toString(s), "int2", standardConformingStrings); case Oid.INT4: int i = ByteConverter.int4((byte[]) paramValue, 0); - return Integer.toString(i); + return quoteAndCast(Integer.toString(i), "int4", standardConformingStrings); case Oid.INT8: long l = ByteConverter.int8((byte[]) paramValue, 0); - return Long.toString(l); + return quoteAndCast(Long.toString(l), "int8", standardConformingStrings); case Oid.FLOAT4: float f = ByteConverter.float4((byte[]) paramValue, 0); if (Float.isNaN(f)) { - return "'NaN'::real"; + return "('NaN'::real)"; } - return Float.toString(f); + return quoteAndCast(Float.toString(f), "float", standardConformingStrings); case Oid.FLOAT8: double d = ByteConverter.float8((byte[]) paramValue, 0); if 
(Double.isNaN(d)) { - return "'NaN'::double precision"; + return "('NaN'::double precision)"; } - return Double.toString(d); + return quoteAndCast(Double.toString(d), "double precision", standardConformingStrings); case Oid.NUMERIC: Number n = ByteConverter.numeric((byte[]) paramValue); if (n instanceof Double) { assert ((Double) n).isNaN(); - return "'NaN'::numeric"; + return "('NaN'::numeric)"; } return n.toString(); case Oid.UUID: String uuid = new UUIDArrayAssistant().buildElement((byte[]) paramValue, 0, 16).toString(); - return "'" + uuid + "'::uuid"; + return quoteAndCast(uuid, "uuid", standardConformingStrings); case Oid.POINT: PGpoint pgPoint = new PGpoint(); pgPoint.setByteValue((byte[]) paramValue, 0); - return "'" + pgPoint.toString() + "'::point"; + return quoteAndCast(pgPoint.toString(), "point", standardConformingStrings); case Oid.BOX: PGbox pgBox = new PGbox(); pgBox.setByteValue((byte[]) paramValue, 0); - return "'" + pgBox.toString() + "'::box"; + return quoteAndCast(pgBox.toString(), "box", standardConformingStrings); } return "?"; } else { String param = paramValue.toString(); - - // add room for quotes + potential escaping. - StringBuilder p = new StringBuilder(3 + (param.length() + 10) / 10 * 11); - - // No E'..' here since escapeLiteral escapes all things and it does not use \123 kind of - // escape codes - p.append('\''); - try { - p = Utils.escapeLiteral(p, param, standardConformingStrings); - } catch (SQLException sqle) { - // This should only happen if we have an embedded null - // and there's not much we can do if we do hit one. - // - // The goal of toString isn't to be sent to the server, - // so we aren't 100% accurate (see StreamWrapper), put - // the unescaped version of the data. 
- // - p.append(param); - } - p.append('\''); int paramType = paramTypes[index]; if (paramType == Oid.TIMESTAMP) { - p.append("::timestamp"); + return quoteAndCast(param, "timestamp", standardConformingStrings); } else if (paramType == Oid.TIMESTAMPTZ) { - p.append("::timestamp with time zone"); + return quoteAndCast(param, "timestamp with time zone", standardConformingStrings); } else if (paramType == Oid.TIME) { - p.append("::time"); + return quoteAndCast(param, "time", standardConformingStrings); } else if (paramType == Oid.TIMETZ) { - p.append("::time with time zone"); + return quoteAndCast(param, "time with time zone", standardConformingStrings); } else if (paramType == Oid.DATE) { - p.append("::date"); + return quoteAndCast(param, "date", standardConformingStrings); } else if (paramType == Oid.INTERVAL) { - p.append("::interval"); + return quoteAndCast(param, "interval", standardConformingStrings); } else if (paramType == Oid.NUMERIC) { - p.append("::numeric"); + return quoteAndCast(param, "numeric", standardConformingStrings); } - return p.toString(); + return quoteAndCast(param, null, standardConformingStrings); } } diff --git a/src/test/java/org/postgresql/core/v3/V3ParameterListTests.java b/src/test/java/org/postgresql/core/v3/V3ParameterListTests.java index 200004f..19c51cb 100644 --- a/src/test/java/org/postgresql/core/v3/V3ParameterListTests.java +++ b/src/test/java/org/postgresql/core/v3/V3ParameterListTests.java @@ -58,8 +58,8 @@ public class V3ParameterListTests { s2SPL.setIntParameter(4, 8); s1SPL.appendAll(s2SPL); - assertEquals( - "Expected string representation of values does not match outcome.", - "<[1 ,2 ,3 ,4 ,5 ,6 ,7 ,8]>", s1SPL.toString()); + assertEquals("Expected string representation of values does not match outcome.", + "<[('1'::int4) ,('2'::int4) ,('3'::int4) ,('4'::int4) ,('5'::int4) ,('6'::int4) ,('7'::int4) ,('8'::int4)]>", s1SPL.toString()); + } } diff --git a/src/test/java/org/postgresql/jdbc/ParameterInjectionTest.java 
b/src/test/java/org/postgresql/jdbc/ParameterInjectionTest.java new file mode 100644 index 0000000..2a33acd --- /dev/null +++ b/src/test/java/org/postgresql/jdbc/ParameterInjectionTest.java @@ -0,0 +1,67 @@ +/* + * Copyright (c) 2024, PostgreSQL Global Development Group + * See the LICENSE file in the project root for more information. + */ + +package org.postgresql.jdbc; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.postgresql.test.TestUtil; + +import org.junit.jupiter.api.Test; + +import java.sql.Connection; +import java.sql.PreparedStatement; +import java.sql.ResultSet; + +public class ParameterInjectionTest { + @Test + public void negateParameter() throws Exception { + try (Connection conn = TestUtil.openDB()) { + PreparedStatement stmt = conn.prepareStatement("SELECT -?"); + + stmt.setInt(1, 1); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(1, rs.getMetaData().getColumnCount(), "number of result columns must match"); + int value = rs.getInt(1); + assertEquals(-1, value, "Input value 1"); + } + + stmt.setInt(1, -1); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next()); + assertEquals(1, rs.getMetaData().getColumnCount(), "number of result columns must match"); + int value = rs.getInt(1); + assertEquals(1, value, "Input value -1"); + } + } + } + + @Test + public void negateParameterWithContinuation() throws Exception { + try (Connection conn = TestUtil.openDB()) { + PreparedStatement stmt = conn.prepareStatement("SELECT -?, ?"); + + stmt.setInt(1, 1); + stmt.setString(2, "\nWHERE false --"); + try (ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "ResultSet should contain a row"); + assertEquals(2, rs.getMetaData().getColumnCount(), "number of result columns must match"); + int value = rs.getInt(1); + assertEquals(-1, value, "Input value 1"); + } + + stmt.setInt(1, -1); + stmt.setString(2, "\nWHERE false --"); + try
(ResultSet rs = stmt.executeQuery()) { + assertTrue(rs.next(), "ResultSet should contain a row"); + assertEquals(2, rs.getMetaData().getColumnCount(), "number of result columns must match"); + int value = rs.getInt(1); + assertEquals(1, value, "Input value -1"); + } + } + } +} -- 2.33.0