From f9b2f236e6a663011b50bd7e9a41ec90e6b94831 Mon Sep 17 00:00:00 2001
From: Lyor Goldstein <lgoldstein@apache.org>
Date: Thu, 25 Feb 2021 21:05:49 +0200
Subject: [PATCH] [SSHD-1125] Added mechanism to throttle pending write requests in BufferedIoOutputStream

---
 .../channel/BufferedIoOutputStream.java       | 188 ++++++++++++++++--
 .../SshChannelBufferedOutputException.java    |  41 ++++
 .../sshd/core/CoreModuleProperties.java       |  19 ++
 .../sshd/util/test/AsyncEchoShellFactory.java |  13 +-
 .../server/subsystem/sftp/SftpSubsystem.java  |   3 +-
 5 files changed, 244 insertions(+), 20 deletions(-)
 create mode 100644 sshd-core/src/main/java/org/apache/sshd/common/channel/exception/SshChannelBufferedOutputException.java

diff --git a/sshd-core/src/main/java/org/apache/sshd/common/channel/BufferedIoOutputStream.java b/sshd-core/src/main/java/org/apache/sshd/common/channel/BufferedIoOutputStream.java
index 1cb75aa..c95a449 100644
--- a/sshd-core/src/main/java/org/apache/sshd/common/channel/BufferedIoOutputStream.java
+++ b/sshd-core/src/main/java/org/apache/sshd/common/channel/BufferedIoOutputStream.java
@@ -20,29 +20,55 @@ package org.apache.sshd.common.channel;
 
 import java.io.EOFException;
 import java.io.IOException;
+import java.time.Duration;
+import java.util.Objects;
 import java.util.Queue;
 import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.atomic.AtomicLong;
 import java.util.concurrent.atomic.AtomicReference;
 
 import org.apache.sshd.common.Closeable;
+import org.apache.sshd.common.PropertyResolver;
+import org.apache.sshd.common.channel.exception.SshChannelBufferedOutputException;
 import org.apache.sshd.common.future.SshFutureListener;
 import org.apache.sshd.common.io.IoOutputStream;
 import org.apache.sshd.common.io.IoWriteFuture;
+import org.apache.sshd.common.util.GenericUtils;
+import org.apache.sshd.common.util.ValidateUtils;
 import org.apache.sshd.common.util.buffer.Buffer;
 import org.apache.sshd.common.util.closeable.AbstractInnerCloseable;
+import org.apache.sshd.core.CoreModuleProperties;
 
 /**
  * An {@link IoOutputStream} capable of queuing write requests
  */
 public class BufferedIoOutputStream extends AbstractInnerCloseable implements IoOutputStream {
+    protected final Object id;
+    protected final int channelId;
+    protected final int maxPendingBytesCount;
+    protected final Duration maxWaitForPendingWrites;
     protected final IoOutputStream out;
+    protected final AtomicInteger pendingBytesCount = new AtomicInteger();
+    protected final AtomicLong writtenBytesCount = new AtomicLong();
     protected final Queue<IoWriteFutureImpl> writes = new ConcurrentLinkedQueue<>();
     protected final AtomicReference<IoWriteFutureImpl> currentWrite = new AtomicReference<>();
-    protected final Object id;
+    protected final AtomicReference<IOException> pendingException = new AtomicReference<>();
 
-    public BufferedIoOutputStream(Object id, IoOutputStream out) {
-        this.out = out;
-        this.id = id;
+    public BufferedIoOutputStream(Object id, int channelId, IoOutputStream out, PropertyResolver resolver) {
+        this(id, channelId, out, CoreModuleProperties.BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_SIZE.getRequired(resolver),
+                CoreModuleProperties.BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_WAIT.getRequired(resolver));
+    }
+
+    public BufferedIoOutputStream(
+            Object id, int channelId, IoOutputStream out, int maxPendingBytesCount,
+            Duration maxWaitForPendingWrites) {
+        this.id = Objects.requireNonNull(id, "No stream identifier provided");
+        this.channelId = channelId;
+        this.out = Objects.requireNonNull(out, "No delegate output stream provided");
+        this.maxPendingBytesCount = maxPendingBytesCount;
+        ValidateUtils.checkTrue(maxPendingBytesCount > 0, "Invalid max. pending bytes count: %d", maxPendingBytesCount);
+        this.maxWaitForPendingWrites = Objects.requireNonNull(maxWaitForPendingWrites, "No max. pending time value provided");
     }
 
     public Object getId() {
@@ -52,26 +78,115 @@ public class BufferedIoOutputStream extends AbstractInnerCloseable implements Io
     @Override
     public IoWriteFuture writePacket(Buffer buffer) throws IOException {
         if (isClosing()) {
-            throw new EOFException("Closed");
+            throw new EOFException("Closed/ing - state=" + state);
         }
 
+        waitForAvailableWriteSpace(buffer.available());
+
         IoWriteFutureImpl future = new IoWriteFutureImpl(getId(), buffer);
         writes.add(future);
         startWriting();
         return future;
     }
 
+    protected void waitForAvailableWriteSpace(int requiredSize) throws IOException {
+        /*
+         * NOTE: this code allows a single pending write to give this mechanism "the slip" and
+         * exit the loop "unscathed" even though there is a pending exception. However, the goal
+         * here is to avoid an OOM by having an unlimited accumulation of pending write requests
+         * due to the fact that the peer is not consuming the sent data. Please note that the pending
+         * exception is "sticky" - i.e., the next write attempt will fail. This also means that if
+         * the write request that "got away" was the last one by chance and it was consumed by the
+         * peer there will be no exception thrown - which is also fine since as mentioned the goal
+         * is not to enforce a strict limit on the pending bytes size but rather on the accumulation
+         * of the pending write requests.
+         *
+         * We could have counted pending requests rather than bytes. However, we also want to avoid
+         * having a large amount of data pending consumption by the peer as well. This code strikes
+         * such a balance by allowing a single pending request to exceed the limit, but at the same
+         * time prevents too many bytes from pending by having a bunch of pending requests that while
+         * below the imposed number limit may cumulatively represent a lot of pending bytes.
+         */
+
+        long expireTime = System.currentTimeMillis() + maxWaitForPendingWrites.toMillis();
+        synchronized (pendingBytesCount) {
+            for (int count = pendingBytesCount.get();
+                 /*
+                  * The (count > 0) condition is put in place to allow a single pending
+                  * write to exceed the maxPendingBytesCount as long as there are no
+                  * other pending ones.
+                  */
+                 (count > 0)
+                         // Not already over the limit or about to be over it
+                         && ((count + requiredSize) > maxPendingBytesCount)
+                         // No pending exception signaled
+                         && (pendingException.get() == null);
+                 count = pendingBytesCount.get()) {
+                long remTime = expireTime - System.currentTimeMillis();
+                if (remTime <= 0L) {
+                    pendingException.compareAndSet(null,
+                            new SshChannelBufferedOutputException(
+                                    channelId,
+                                    "Max. pending write timeout expired after " + writtenBytesCount + " bytes"));
+                    throw pendingException.get();
+                }
+
+                try {
+                    pendingBytesCount.wait(remTime);
+                } catch (InterruptedException e) {
+                    Thread.currentThread().interrupt(); // restore the interrupt status for the caller
+                    pendingException.compareAndSet(null,
+                            new SshChannelBufferedOutputException(
+                                    channelId,
+                                    "Waiting for pending writes interrupted after " + writtenBytesCount + " bytes", e));
+                    throw pendingException.get();
+                }
+            }
+
+            IOException e = pendingException.get();
+            if (e != null) {
+                throw e;
+            }
+
+            pendingBytesCount.addAndGet(requiredSize);
+        }
+    }
+
     protected void startWriting() throws IOException {
         IoWriteFutureImpl future = writes.peek();
+        // No more pending requests
         if (future == null) {
             return;
         }
 
+        // Don't try to write any further if pending exception signaled
+        Throwable pendingError = pendingException.get();
+        if (pendingError != null) {
+            log.error("startWriting({})[{}] propagate to {} write requests pending error={}[{}]",
+                    getId(), out, writes.size(), pendingError.getClass().getSimpleName(), pendingError.getMessage());
+
+            IoWriteFutureImpl currentFuture = currentWrite.getAndSet(null);
+            for (IoWriteFutureImpl pendingWrite : writes) {
+                // Checking reference by design
+                if (GenericUtils.isSameReference(pendingWrite, currentFuture)) {
+                    continue; // will be taken care of when its listener is eventually called
+                }
+
+                pendingWrite.setValue(pendingError); // fail each queued (not in-flight) request
+            }
+
+            writes.clear();
+            return;
+        }
+
+        // Cannot honor this request yet since other pending one incomplete
         if (!currentWrite.compareAndSet(null, future)) {
             return;
         }
 
-        out.writePacket(future.getBuffer()).addListener(new SshFutureListener<IoWriteFuture>() {
+        Buffer buffer = future.getBuffer();
+        int bufferSize = buffer.available();
+        out.writePacket(buffer).addListener(new SshFutureListener<IoWriteFuture>() {
             @Override
             public void operationComplete(IoWriteFuture f) {
                 if (f.isWritten()) {
@@ -79,32 +194,71 @@ public class BufferedIoOutputStream extends AbstractInnerCloseable implements Io
                 } else {
                     future.setValue(f.getException());
                 }
-                finishWrite();
+                finishWrite(future, bufferSize);
+            }
+        });
+    }
+
+    protected void finishWrite(IoWriteFutureImpl future, int bufferSize) {
+        /*
+         * Update the pending bytes count only if successfully written,
+         * otherwise signal an error
+         */
+        if (future.isWritten()) {
+            long writtenSize = writtenBytesCount.addAndGet(bufferSize);
+            int stillPending;
+            synchronized (pendingBytesCount) {
+                stillPending = pendingBytesCount.addAndGet(0 - bufferSize);
+                pendingBytesCount.notifyAll();
             }
 
-            @SuppressWarnings("synthetic-access")
-            private void finishWrite() {
+            /*
+             * NOTE: since the pending exception is updated outside the synchronized block
+             * a pending write could be successfully enqueued, however this is acceptable
+             * - see comment in waitForAvailableWriteSpace
+             */
+            if (stillPending < 0) {
+                log.error("finishWrite({})[{}] - pending byte counts underflow ({}) after {} bytes", getId(), out, stillPending,
+                        writtenSize);
+                pendingException.compareAndSet(null,
+                        new SshChannelBufferedOutputException(channelId, "Pending byte counts underflow"));
+            }
+        } else {
+            Throwable t = future.getException();
+            if (t instanceof SshChannelBufferedOutputException) {
+                pendingException.compareAndSet(null, (SshChannelBufferedOutputException) t);
+            } else {
+                pendingException.compareAndSet(null, new SshChannelBufferedOutputException(channelId, t));
+            }
+
+            // In case someone waiting so that they can detect the exception
+            synchronized (pendingBytesCount) {
+                pendingBytesCount.notifyAll();
+            }
+        }
+
         writes.remove(future);
         currentWrite.compareAndSet(future, null);
         try {
             startWriting();
         } catch (IOException e) {
-                    log.error("finishWrite({}) failed ({}) re-start writing", out, e.getClass().getSimpleName());
+            if (e instanceof SshChannelBufferedOutputException) {
+                pendingException.compareAndSet(null, (SshChannelBufferedOutputException) e);
+            } else {
+                pendingException.compareAndSet(null, new SshChannelBufferedOutputException(channelId, e));
+            }
+            log.error("finishWrite({})[{}] failed ({}) re-start writing: {}",
+                    getId(), out, e.getClass().getSimpleName(), e.getMessage(), e);
         }
     }
-        });
-    }
 
     @Override
     protected Closeable getInnerCloseable() {
-        return builder()
-                .when(getId(), writes)
-                .close(out)
-                .build();
+        return builder().when(getId(), writes).close(out).build();
     }
 
     @Override
     public String toString() {
-        return getClass().getSimpleName() + "[" + out + "]";
+        return getClass().getSimpleName() + "(" + getId() + ")[" + out + "]";
     }
 }
diff --git a/sshd-core/src/main/java/org/apache/sshd/common/channel/exception/SshChannelBufferedOutputException.java b/sshd-core/src/main/java/org/apache/sshd/common/channel/exception/SshChannelBufferedOutputException.java
new file mode 100644
index 0000000..97e6105
--- /dev/null
+++ b/sshd-core/src/main/java/org/apache/sshd/common/channel/exception/SshChannelBufferedOutputException.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sshd.common.channel.exception;
+
+/**
+ * Used by the {@code BufferedIoOutputStream} to signal a non-recoverable error
+ *
+ * @author Apache MINA SSHD Project
+ */
+public class SshChannelBufferedOutputException extends SshChannelException {
+    private static final long serialVersionUID = -8663890657820958046L;
+
+    public SshChannelBufferedOutputException(int channelId, String message) {
+        this(channelId, message, null);
+    }
+
+    public SshChannelBufferedOutputException(int channelId, Throwable cause) {
+        this(channelId, cause.getMessage(), cause);
+    }
+
+    public SshChannelBufferedOutputException(int channelId, String message, Throwable cause) {
+        super(channelId, message, cause);
+    }
+}
diff --git a/sshd-core/src/main/java/org/apache/sshd/core/CoreModuleProperties.java b/sshd-core/src/main/java/org/apache/sshd/core/CoreModuleProperties.java
index 9e9b2d2..0d122e5 100644
--- a/sshd-core/src/main/java/org/apache/sshd/core/CoreModuleProperties.java
+++ b/sshd-core/src/main/java/org/apache/sshd/core/CoreModuleProperties.java
@@ -24,6 +24,7 @@ import java.time.Duration;
 
 import org.apache.sshd.client.config.keys.ClientIdentityLoader;
 import org.apache.sshd.common.Property;
+import org.apache.sshd.common.SshConstants;
 import org.apache.sshd.common.channel.Channel;
 import org.apache.sshd.common.session.Session;
 import org.apache.sshd.common.util.OsUtils;
@@ -240,6 +241,24 @@ public final class CoreModuleProperties {
     public static final Property<Duration> WINDOW_TIMEOUT
             = Property.duration("window-timeout", Duration.ZERO);
 
+    /**
+     * Key used when creating a {@code BufferedIoOutputStream} in order to specify max. allowed unwritten pending bytes.
+     * If this value is exceeded then the code waits up to {@link #BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_WAIT} for the
+     * pending data to be written and thus make room for the new request.
+     */
+    public static final Property<Integer> BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_SIZE
+            = Property.integer("buffered-io-output-max-pending-write-size",
+                    SshConstants.SSH_REQUIRED_PAYLOAD_PACKET_LENGTH_SUPPORT * 8);
+
+    /**
+     * Key used when creating a {@code BufferedIoOutputStream} in order to specify max. wait time for pending
+     * writes to be completed before enqueuing a new request
+     *
+     * @see #BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_SIZE
+     */
+    public static final Property<Duration> BUFFERED_IO_OUTPUT_MAX_PENDING_WRITE_WAIT
+            = Property.duration("buffered-io-output-max-pending-write-wait", Duration.ofSeconds(30L));
+
     /**
      * Key used to retrieve the value of the maximum packet size in the configuration properties map.
      */
diff --git a/sshd-core/src/test/java/org/apache/sshd/util/test/AsyncEchoShellFactory.java b/sshd-core/src/test/java/org/apache/sshd/util/test/AsyncEchoShellFactory.java
index de9dbf4..465ff83 100644
--- a/sshd-core/src/test/java/org/apache/sshd/util/test/AsyncEchoShellFactory.java
+++ b/sshd-core/src/test/java/org/apache/sshd/util/test/AsyncEchoShellFactory.java
@@ -99,12 +99,21 @@ public class AsyncEchoShellFactory implements ShellFactory {
 
         @Override
         public void setIoOutputStream(IoOutputStream out) {
-            this.out = new BufferedIoOutputStream("STDOUT", out);
+            this.out = wrapOutputStream("SHELL-STDOUT", out);
         }
 
         @Override
         public void setIoErrorStream(IoOutputStream err) {
-            this.err = new BufferedIoOutputStream("STDERR", err);
+            this.err = wrapOutputStream("SHELL-STDERR", err);
+        }
+
+        protected BufferedIoOutputStream wrapOutputStream(String prefix, IoOutputStream stream) {
+            if (stream instanceof BufferedIoOutputStream) {
+                return (BufferedIoOutputStream) stream;
+            }
+
+            int channelId = session.getId();
+            return new BufferedIoOutputStream(prefix + "@" + channelId, channelId, stream, session);
         }
 
         @Override
diff --git a/sshd-sftp/src/main/java/org/apache/sshd/server/subsystem/sftp/SftpSubsystem.java b/sshd-sftp/src/main/java/org/apache/sshd/server/subsystem/sftp/SftpSubsystem.java
index 66a0ced..15201ec 100644
--- a/sshd-sftp/src/main/java/org/apache/sshd/server/subsystem/sftp/SftpSubsystem.java
+++ b/sshd-sftp/src/main/java/org/apache/sshd/server/subsystem/sftp/SftpSubsystem.java
@@ -256,7 +256,8 @@ public class SftpSubsystem
 
     @Override
     public void setIoOutputStream(IoOutputStream out) {
-        this.out = new BufferedIoOutputStream("sftp out buffer", out);
+        int channelId = channelSession.getId();
+        this.out = new BufferedIoOutputStream("sftp-out@" + channelId, channelId, out, channelSession);
     }
 
     @Override
-- 
2.27.0