logback/CVE-2017-5929.patch
2020-09-19 16:22:54 +08:00

304 lines
11 KiB
Diff

From f46044b805bca91efe5fd6afe52257cd02f775f8 Mon Sep 17 00:00:00 2001
From: Ceki Gulcu <ceki@qos.ch>
Date: Tue, 7 Feb 2017 23:12:51 +0100
Subject: [PATCH] harden serialization
---
.../classic/net/SimpleSocketServer.java | 1 -
.../LogbackClassicSerializationHelper.java | 28 ++++++++
.../classic/LoggerSerializationTest.java | 10 ++-
.../core/net/HardenedObjectInputStream.java | 48 +++++++++++++
.../net/HardenedObjectInputStreamTest.java | 61 ++++++++++++++++
.../ch/qos/logback/core/net/Innocent.java | 69 +++++++++++++++++++
6 files changed, 214 insertions(+), 3 deletions(-)
create mode 100755 logback-classic/src/main/java/ch/qos/logback/classic/net/server/LogbackClassicSerializationHelper.java
create mode 100755 logback-core/src/main/java/ch/qos/logback/core/net/HardenedObjectInputStream.java
create mode 100755 logback-core/src/test/java/ch/qos/logback/core/net/HardenedObjectInputStreamTest.java
create mode 100755 logback-core/src/test/java/ch/qos/logback/core/net/Innocent.java
diff --git a/logback-classic/src/main/java/ch/qos/logback/classic/net/SimpleSocketServer.java b/logback-classic/src/main/java/ch/qos/logback/classic/net/SimpleSocketServer.java
index 6d39a2473..3083f45ce 100755
--- a/logback-classic/src/main/java/ch/qos/logback/classic/net/SimpleSocketServer.java
+++ b/logback-classic/src/main/java/ch/qos/logback/classic/net/SimpleSocketServer.java
@@ -14,7 +14,6 @@
package ch.qos.logback.classic.net;
import java.io.IOException;
-import java.lang.reflect.Constructor;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.ArrayList;
diff --git a/logback-classic/src/main/java/ch/qos/logback/classic/net/server/LogbackClassicSerializationHelper.java b/logback-classic/src/main/java/ch/qos/logback/classic/net/server/LogbackClassicSerializationHelper.java
new file mode 100755
index 000000000..00a974f81
--- /dev/null
+++ b/logback-classic/src/main/java/ch/qos/logback/classic/net/server/LogbackClassicSerializationHelper.java
@@ -0,0 +1,39 @@
+package ch.qos.logback.classic.net.server;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.slf4j.helpers.BasicMarker;
+
+import ch.qos.logback.classic.Logger;
+import ch.qos.logback.classic.spi.LoggerContextVO;
+import ch.qos.logback.classic.spi.LoggingEventVO;
+import ch.qos.logback.classic.spi.ThrowableProxyVO;
+
+/**
+ * Supplies the names of the logback-classic classes that may be
+ * deserialized through a whitelist-checking object input stream.
+ */
+public class LogbackClassicSerializationHelper {
+
+    /**
+     * Builds the list of class names accepted during deserialization.
+     *
+     * <p>The misspelled method name is kept because existing callers use it.
+     *
+     * @return a new, mutable list on every call; callers may append their own entries
+     */
+    static public List<String> getWhilelist() {
+        List<String> whitelist = new ArrayList<String>();
+        whitelist.add(LoggingEventVO.class.getName());
+        whitelist.add(LoggerContextVO.class.getName());
+        whitelist.add(ThrowableProxyVO.class.getName());
+        whitelist.add(StackTraceElement.class.getName());
+        // An array type has its own class name ("[Ljava.lang.StackTraceElement;"),
+        // so listing the element type alone does not cover it.
+        whitelist.add(StackTraceElement[].class.getName());
+        whitelist.add(BasicMarker.class.getName());
+        whitelist.add(Logger.class.getName());
+        return whitelist;
+    }
+}
diff --git a/logback-classic/src/test/java/ch/qos/logback/classic/LoggerSerializationTest.java b/logback-classic/src/test/java/ch/qos/logback/classic/LoggerSerializationTest.java
index ec6cb01d7..618d1756e 100644
--- a/logback-classic/src/test/java/ch/qos/logback/classic/LoggerSerializationTest.java
+++ b/logback-classic/src/test/java/ch/qos/logback/classic/LoggerSerializationTest.java
@@ -14,7 +14,10 @@
package ch.qos.logback.classic;
import java.io.*;
+import java.util.List;
+import ch.qos.logback.classic.net.server.LogbackClassicSerializationHelper;
+import ch.qos.logback.core.net.HardenedObjectInputStream;
import ch.qos.logback.core.util.CoreTestConstants;
import org.junit.After;
import org.junit.Before;
@@ -36,7 +39,8 @@
ByteArrayOutputStream bos;
ObjectOutputStream oos;
ObjectInputStream inputStream;
-
+ List<String> whitelist ;
+
@Before
public void setUp() throws Exception {
lc = new LoggerContext();
@@ -45,6 +49,8 @@ public void setUp() throws Exception {
// create the byte output stream
bos = new ByteArrayOutputStream();
oos = new ObjectOutputStream(bos);
+ whitelist = LogbackClassicSerializationHelper.getWhilelist();
+ whitelist.add(Foo.class.getName());
}
@After
@@ -110,7 +116,7 @@ public void deepTreeSerialization() throws IOException {
private Foo writeAndRead(Foo foo) throws IOException, ClassNotFoundException {
writeObject(oos, foo);
ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
- inputStream = new ObjectInputStream(bis);
+ inputStream = new HardenedObjectInputStream(bis, whitelist);
Foo fooBack = readFooObject(inputStream);
inputStream.close();
return fooBack;
diff --git a/logback-core/src/main/java/ch/qos/logback/core/net/HardenedObjectInputStream.java b/logback-core/src/main/java/ch/qos/logback/core/net/HardenedObjectInputStream.java
new file mode 100755
index 000000000..439e2bde5
--- /dev/null
+++ b/logback-core/src/main/java/ch/qos/logback/core/net/HardenedObjectInputStream.java
@@ -0,0 +1,90 @@
+package ch.qos.logback.core.net;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InvalidClassException;
+import java.io.ObjectInputStream;
+import java.io.ObjectStreamClass;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * An {@link ObjectInputStream} that resolves a class only when its name is on
+ * the supplied whitelist or lies under the {@code java.lang} or
+ * {@code java.util} packages. Arrays are accepted when their element type is.
+ *
+ * @author Ceki G&uuml;lc&uuml;
+ * @since 1.2.0
+ */
+public class HardenedObjectInputStream extends ObjectInputStream {
+
+    /** Type codes that follow the '[' markers in the name of a primitive array class. */
+    private static final String PRIMITIVE_TYPE_CODES = "BCDFIJSZ";
+
+    final List<String> whitelistedClassNames;
+    // The trailing dot stops look-alike package names such as "java.language" from matching.
+    final String[] javaPackages = new String[] { "java.lang.", "java.util." };
+
+    /**
+     * @param in the stream to deserialize from
+     * @param whilelist fully qualified names of the additional classes to accept;
+     *        the list is copied, so later changes to it have no effect
+     * @throws IOException if reading the stream header fails
+     */
+    public HardenedObjectInputStream(InputStream in, List<String> whilelist) throws IOException {
+        super(in);
+        this.whitelistedClassNames = Collections.synchronizedList(new ArrayList<String>(whilelist));
+    }
+
+    /**
+     * Refuses any class descriptor that is not whitelisted, then delegates to
+     * the default resolution.
+     *
+     * @throws InvalidClassException if the incoming class name is not whitelisted
+     */
+    @Override
+    protected Class<?> resolveClass(ObjectStreamClass anObjectStreamClass) throws IOException, ClassNotFoundException {
+        String incomingClassName = anObjectStreamClass.getName();
+        if (!isWhitelisted(incomingClassName)) {
+            // InvalidClassException expects (class name, reason), in that order.
+            throw new InvalidClassException(incomingClassName, "Unauthorized deserialization attempt");
+        }
+        return super.resolveClass(anObjectStreamClass);
+    }
+
+    /**
+     * Tells whether a class descriptor name may be resolved. An array name such
+     * as {@code [Ljava.lang.String;} is judged by its element type, and arrays
+     * of primitives are always accepted.
+     */
+    private boolean isWhitelisted(String incomingClassName) {
+        if (whitelistedClassNames.contains(incomingClassName)) {
+            return true;
+        }
+        String elementName = incomingClassName;
+        int dimensions = 0;
+        while (dimensions < incomingClassName.length() && incomingClassName.charAt(dimensions) == '[') {
+            dimensions++;
+        }
+        if (dimensions > 0) {
+            String component = incomingClassName.substring(dimensions);
+            if (component.length() == 1) {
+                return PRIMITIVE_TYPE_CODES.indexOf(component.charAt(0)) >= 0;
+            }
+            if (component.length() < 3 || component.charAt(0) != 'L' || !component.endsWith(";")) {
+                return false; // not a well-formed array name
+            }
+            elementName = component.substring(1, component.length() - 1);
+            if (whitelistedClassNames.contains(elementName)) {
+                return true;
+            }
+        }
+        for (String javaPackage : javaPackages) {
+            if (elementName.startsWith(javaPackage)) {
+                return true;
+            }
+        }
+        return false;
+    }
+}
diff --git a/logback-core/src/test/java/ch/qos/logback/core/net/HardenedObjectInputStreamTest.java b/logback-core/src/test/java/ch/qos/logback/core/net/HardenedObjectInputStreamTest.java
new file mode 100755
index 000000000..6a3489755
--- /dev/null
+++ b/logback-core/src/test/java/ch/qos/logback/core/net/HardenedObjectInputStreamTest.java
@@ -0,0 +1,69 @@
+package ch.qos.logback.core.net;
+
+import static org.junit.Assert.*;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.IOException;
+import java.io.ObjectOutputStream;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+/**
+ * Round-trips an {@link Innocent} instance through {@link HardenedObjectInputStream}.
+ */
+public class HardenedObjectInputStreamTest {
+
+    ByteArrayOutputStream bos;
+    ObjectOutputStream oos;
+    HardenedObjectInputStream inputStream;
+    List<String> whitelist = new ArrayList<String>();
+
+    /** Opens a fresh in-memory object stream and whitelists {@link Innocent}. */
+    @Before
+    public void setUp() throws Exception {
+        bos = new ByteArrayOutputStream();
+        oos = new ObjectOutputStream(bos);
+        whitelist.add(Innocent.class.getName());
+    }
+
+    @After
+    public void tearDown() throws Exception {
+        // intentionally empty
+    }
+
+    /** A whitelisted object must come back equal to the one that was written. */
+    @Test
+    public void smoke() throws ClassNotFoundException, IOException {
+        Innocent written = new Innocent();
+        written.setaString("smoke");
+        written.setAnInteger(2);
+        written.setAnInt(1);
+
+        Innocent restored = writeAndRead(written);
+
+        assertEquals(written, restored);
+    }
+
+    /** Serializes {@code innocent} into the buffer and reads it back through the hardened stream. */
+    private Innocent writeAndRead(Innocent innocent) throws IOException, ClassNotFoundException {
+        writeObject(oos, innocent);
+        byte[] serialized = bos.toByteArray();
+        inputStream = new HardenedObjectInputStream(new ByteArrayInputStream(serialized), whitelist);
+        Innocent restored = (Innocent) inputStream.readObject();
+        inputStream.close();
+        return restored;
+    }
+
+    /** Writes {@code o}, then flushes and closes the stream. */
+    private void writeObject(ObjectOutputStream oos, Object o) throws IOException {
+        oos.writeObject(o);
+        oos.flush();
+        oos.close();
+    }
+
+}
diff --git a/logback-core/src/test/java/ch/qos/logback/core/net/Innocent.java b/logback-core/src/test/java/ch/qos/logback/core/net/Innocent.java
new file mode 100755
index 000000000..2cef5a08e
--- /dev/null
+++ b/logback-core/src/test/java/ch/qos/logback/core/net/Innocent.java
@@ -0,0 +1,72 @@
+package ch.qos.logback.core.net;
+
+/**
+ * Plain serializable value holder used as a test fixture; equality and hash
+ * code are derived from all three fields.
+ */
+public class Innocent implements java.io.Serializable {
+
+    private static final long serialVersionUID = -1227008349289885025L;
+
+    int anInt;
+    Integer anInteger;
+    String aString;
+
+    public int getAnInt() {
+        return anInt;
+    }
+
+    public void setAnInt(int anInt) {
+        this.anInt = anInt;
+    }
+
+    public Integer getAnInteger() {
+        return anInteger;
+    }
+
+    public void setAnInteger(Integer anInteger) {
+        this.anInteger = anInteger;
+    }
+
+    public String getaString() {
+        return aString;
+    }
+
+    public void setaString(String aString) {
+        this.aString = aString;
+    }
+
+    /** Combines the three fields with a multiply-by-31 scheme; a null field contributes 0. */
+    @Override
+    public int hashCode() {
+        int hash = 31 + hashOrZero(aString);
+        hash = 31 * hash + anInt;
+        hash = 31 * hash + hashOrZero(anInteger);
+        return hash;
+    }
+
+    /** Two instances are equal when they share the exact class and all three field values. */
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) {
+            return true;
+        }
+        if (obj == null || getClass() != obj.getClass()) {
+            return false;
+        }
+        Innocent that = (Innocent) obj;
+        return anInt == that.anInt
+                && sameValue(aString, that.aString)
+                && sameValue(anInteger, that.anInteger);
+    }
+
+    private static int hashOrZero(Object value) {
+        return value == null ? 0 : value.hashCode();
+    }
+
+    /** Null-safe equality: two nulls match, a null never matches a non-null. */
+    private static boolean sameValue(Object left, Object right) {
+        return left == null ? right == null : left.equals(right);
+    }
+
+}