From f46044b805bca91efe5fd6afe52257cd02f775f8 Mon Sep 17 00:00:00 2001
From: Ceki Gulcu
Date: Tue, 7 Feb 2017 23:12:51 +0100
Subject: [PATCH] harden serialization

---
Review notes (text in this section is not part of the commit message):
* restored the generic type arguments (List<String>, Class<?>) which the
  enhanced for loop over the whitelist needs in order to compile
* InvalidClassException now receives the class name first and the reason
  second, as its constructor expects
* the accepted java package prefixes end with a dot, so a look-alike
  package such as "java.utilx" is no longer accepted
* the whitelist lookup uses the synchronized contains() instead of
  iterating over the synchronized list without holding its lock
* BasicMarker is added to the whitelist only once
* added a test showing that a non-whitelisted class is rejected
* blob ids on the index lines of the three modified new files were not
  recomputed

 .../classic/net/SimpleSocketServer.java     |  1 -
 .../LogbackClassicSerializationHelper.java  | 35 +++++++++
 .../classic/LoggerSerializationTest.java    | 10 ++-
 .../core/net/HardenedObjectInputStream.java | 63 ++++++++++++++++
 .../net/HardenedObjectInputStreamTest.java  | 74 +++++++++++++++++++
 .../ch/qos/logback/core/net/Innocent.java   | 69 +++++++++++++++++
 6 files changed, 249 insertions(+), 3 deletions(-)
 create mode 100755 logback-classic/src/main/java/ch/qos/logback/classic/net/server/LogbackClassicSerializationHelper.java
 create mode 100755 logback-core/src/main/java/ch/qos/logback/core/net/HardenedObjectInputStream.java
 create mode 100755 logback-core/src/test/java/ch/qos/logback/core/net/HardenedObjectInputStreamTest.java
 create mode 100755 logback-core/src/test/java/ch/qos/logback/core/net/Innocent.java

diff --git a/logback-classic/src/main/java/ch/qos/logback/classic/net/SimpleSocketServer.java b/logback-classic/src/main/java/ch/qos/logback/classic/net/SimpleSocketServer.java
index 6d39a2473..3083f45ce 100755
--- a/logback-classic/src/main/java/ch/qos/logback/classic/net/SimpleSocketServer.java
+++ b/logback-classic/src/main/java/ch/qos/logback/classic/net/SimpleSocketServer.java
@@ -14,7 +14,6 @@
 package ch.qos.logback.classic.net;
 
 import java.io.IOException;
-import java.lang.reflect.Constructor;
 import java.net.ServerSocket;
 import java.net.Socket;
 import java.util.ArrayList;
diff --git a/logback-classic/src/main/java/ch/qos/logback/classic/net/server/LogbackClassicSerializationHelper.java b/logback-classic/src/main/java/ch/qos/logback/classic/net/server/LogbackClassicSerializationHelper.java
new file mode 100755
index 000000000..00a974f81
--- /dev/null
+++ b/logback-classic/src/main/java/ch/qos/logback/classic/net/server/LogbackClassicSerializationHelper.java
@@ -0,0 +1,35 @@
+package ch.qos.logback.classic.net.server;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import org.slf4j.helpers.BasicMarker;
+
+import ch.qos.logback.classic.Logger;
+import ch.qos.logback.classic.spi.LoggerContextVO;
+import ch.qos.logback.classic.spi.LoggingEventVO;
+import ch.qos.logback.classic.spi.ThrowableProxyVO;
+
+/**
+ * Supplies the names of the logback-classic classes that may be
+ * deserialized by a hardened object input stream.
+ */
+public class LogbackClassicSerializationHelper {
+
+    /**
+     * Builds a new mutable list on every call, so callers may append their
+     * own entries without affecting other callers.
+     *
+     * @return the fully qualified names of the whitelisted classes
+     */
+    static public List<String> getWhilelist() {
+        List<String> whitelist = new ArrayList<String>();
+        whitelist.add(LoggingEventVO.class.getName());
+        whitelist.add(LoggerContextVO.class.getName());
+        whitelist.add(ThrowableProxyVO.class.getName());
+        whitelist.add(StackTraceElement.class.getName());
+        whitelist.add(BasicMarker.class.getName());
+        whitelist.add(Logger.class.getName());
+        return whitelist;
+    }
+}
diff --git a/logback-classic/src/test/java/ch/qos/logback/classic/LoggerSerializationTest.java b/logback-classic/src/test/java/ch/qos/logback/classic/LoggerSerializationTest.java
index ec6cb01d7..618d1756e 100644
--- a/logback-classic/src/test/java/ch/qos/logback/classic/LoggerSerializationTest.java
+++ b/logback-classic/src/test/java/ch/qos/logback/classic/LoggerSerializationTest.java
@@ -14,7 +14,10 @@
 package ch.qos.logback.classic;
 
 import java.io.*;
+import java.util.List;
 
+import ch.qos.logback.classic.net.server.LogbackClassicSerializationHelper;
+import ch.qos.logback.core.net.HardenedObjectInputStream;
 import ch.qos.logback.core.util.CoreTestConstants;
 import org.junit.After;
 import org.junit.Before;
@@ -36,7 +39,8 @@
     ByteArrayOutputStream bos;
     ObjectOutputStream oos;
     ObjectInputStream inputStream;
-
+    List<String> whitelist;
+
     @Before
     public void setUp() throws Exception {
         lc = new LoggerContext();
@@ -45,6 +49,8 @@ public void setUp() throws Exception {
         // create the byte output stream
         bos = new ByteArrayOutputStream();
         oos = new ObjectOutputStream(bos);
+        whitelist = LogbackClassicSerializationHelper.getWhilelist();
+        whitelist.add(Foo.class.getName());
     }
 
     @After
@@ -110,7 +116,7 @@ public void deepTreeSerialization() throws IOException {
     private Foo writeAndRead(Foo foo) throws IOException, ClassNotFoundException {
         writeObject(oos, foo);
         ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
-        inputStream = new ObjectInputStream(bis);
+        inputStream = new HardenedObjectInputStream(bis, whitelist);
         Foo fooBack = readFooObject(inputStream);
         inputStream.close();
         return fooBack;
diff --git a/logback-core/src/main/java/ch/qos/logback/core/net/HardenedObjectInputStream.java b/logback-core/src/main/java/ch/qos/logback/core/net/HardenedObjectInputStream.java
new file mode 100755
index 000000000..439e2bde5
--- /dev/null
+++ b/logback-core/src/main/java/ch/qos/logback/core/net/HardenedObjectInputStream.java
@@ -0,0 +1,63 @@
+package ch.qos.logback.core.net;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InvalidClassException;
+import java.io.ObjectInputStream;
+import java.io.ObjectStreamClass;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * An {@link ObjectInputStream} that resolves a class only when its name is
+ * whitelisted or belongs to the {@code java.lang} or {@code java.util}
+ * package trees. Every other class is rejected with an
+ * {@link InvalidClassException}.
+ *
+ * @author Ceki Gülcü
+ * @since 1.2.0
+ */
+public class HardenedObjectInputStream extends ObjectInputStream {
+
+    List<String> whitelistedClassNames;
+    // the trailing dot keeps look-alike packages such as "java.utilx" out
+    String[] javaPackages = new String[] {"java.lang.", "java.util."};
+
+    /**
+     * @param in the stream to read serialized objects from
+     * @param whilelist fully qualified names of the additional classes to
+     *        accept; the list is copied
+     * @throws IOException if the {@link ObjectInputStream} constructor fails
+     */
+    public HardenedObjectInputStream(InputStream in, List<String> whilelist) throws IOException {
+        super(in);
+        this.whitelistedClassNames = Collections.synchronizedList(new ArrayList<String>(whilelist));
+    }
+
+    /**
+     * Resolves the described class after checking that it is authorized.
+     *
+     * @throws InvalidClassException if the class is neither whitelisted nor
+     *         part of an accepted java package
+     */
+    @Override
+    protected Class<?> resolveClass(ObjectStreamClass anObjectStreamClass) throws IOException, ClassNotFoundException {
+        String incomingClassName = anObjectStreamClass.getName();
+        if (!isWhitelisted(incomingClassName)) {
+            // InvalidClassException expects the class name first, then the reason
+            throw new InvalidClassException(incomingClassName, "Unauthorized deserialization attempt");
+        }
+
+        return super.resolveClass(anObjectStreamClass);
+    }
+
+    private boolean isWhitelisted(String incomingClassName) {
+        for (String javaPackage : javaPackages) {
+            if (incomingClassName.startsWith(javaPackage)) {
+                return true;
+            }
+        }
+        return whitelistedClassNames.contains(incomingClassName);
+    }
+}
diff --git a/logback-core/src/test/java/ch/qos/logback/core/net/HardenedObjectInputStreamTest.java b/logback-core/src/test/java/ch/qos/logback/core/net/HardenedObjectInputStreamTest.java
new file mode 100755
index 000000000..6a3489755
--- /dev/null
+++ b/logback-core/src/test/java/ch/qos/logback/core/net/HardenedObjectInputStreamTest.java
@@ -0,0 +1,74 @@
+package ch.qos.logback.core.net;
+
+import static org.junit.Assert.*;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.IOException;
+import java.io.InvalidClassException;
+import java.io.ObjectOutputStream;
+import java.util.ArrayList;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+public class HardenedObjectInputStreamTest {
+
+    ByteArrayOutputStream bos;
+    ObjectOutputStream oos;
+    HardenedObjectInputStream inputStream;
+    List<String> whitelist = new ArrayList<String>();
+
+    @Before
+    public void setUp() throws Exception {
+        whitelist.add(Innocent.class.getName());
+        bos = new ByteArrayOutputStream();
+        oos = new ObjectOutputStream(bos);
+    }
+
+    @After
+    public void tearDown() throws Exception {
+    }
+
+    @Test
+    public void smoke() throws ClassNotFoundException, IOException {
+        Innocent innocent = new Innocent();
+        innocent.setAnInt(1);
+        innocent.setAnInteger(2);
+        innocent.setaString("smoke");
+        Innocent back = writeAndRead(innocent);
+        assertEquals(innocent, back);
+    }
+
+    // java.io.File is serializable but neither whitelisted nor under java.lang/java.util
+    @Test(expected = InvalidClassException.class)
+    public void unauthorizedClassIsRejected() throws ClassNotFoundException, IOException {
+        writeObject(oos, new File("unauthorized"));
+        ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
+        inputStream = new HardenedObjectInputStream(bis, whitelist);
+        try {
+            inputStream.readObject();
+        } finally {
+            inputStream.close();
+        }
+    }
+
+    private Innocent writeAndRead(Innocent innocent) throws IOException, ClassNotFoundException {
+        writeObject(oos, innocent);
+        ByteArrayInputStream bis = new ByteArrayInputStream(bos.toByteArray());
+        inputStream = new HardenedObjectInputStream(bis, whitelist);
+        Innocent fooBack = (Innocent) inputStream.readObject();
+        inputStream.close();
+        return fooBack;
+    }
+
+    private void writeObject(ObjectOutputStream oos, Object o) throws IOException {
+        oos.writeObject(o);
+        oos.flush();
+        oos.close();
+    }
+
+}
diff --git a/logback-core/src/test/java/ch/qos/logback/core/net/Innocent.java b/logback-core/src/test/java/ch/qos/logback/core/net/Innocent.java
new file mode 100755
index 000000000..2cef5a08e
--- /dev/null
+++ b/logback-core/src/test/java/ch/qos/logback/core/net/Innocent.java
@@ -0,0 +1,69 @@
+package ch.qos.logback.core.net;
+
+public class Innocent implements java.io.Serializable {
+
+    private static final long serialVersionUID = -1227008349289885025L;
+
+    int anInt;
+    Integer anInteger;
+    String aString;
+
+    public int getAnInt() {
+        return anInt;
+    }
+
+    public void setAnInt(int anInt) {
+        this.anInt = anInt;
+    }
+
+    public Integer getAnInteger() {
+        return anInteger;
+    }
+
+    public void setAnInteger(Integer anInteger) {
+        this.anInteger = anInteger;
+    }
+
+    public String getaString() {
+        return aString;
+    }
+
+    public void setaString(String aString) {
+        this.aString = aString;
+    }
+
+    @Override
+    public int hashCode() {
+        final int prime = 31;
+        int result = 1;
+        result = prime * result + ((aString == null) ? 0 : aString.hashCode());
+        result = prime * result + anInt;
+        result = prime * result + ((anInteger == null) ? 0 : anInteger.hashCode());
+        return result;
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj)
+            return true;
+        if (obj == null)
+            return false;
+        if (getClass() != obj.getClass())
+            return false;
+        Innocent other = (Innocent) obj;
+        if (aString == null) {
+            if (other.aString != null)
+                return false;
+        } else if (!aString.equals(other.aString))
+            return false;
+        if (anInt != other.anInt)
+            return false;
+        if (anInteger == null) {
+            if (other.anInteger != null)
+                return false;
+        } else if (!anInteger.equals(other.anInteger))
+            return false;
+        return true;
+    }
+
+}