From fa46de7feeeb8a01dc471581a0258252ce4f2db6 Mon Sep 17 00:00:00 2001
From: Jeff Forcier
Date: Sat, 16 Dec 2023 17:12:42 -0500
Subject: [PATCH] Reset sequence numbers on rekey

Reference:https://github.com/paramiko/paramiko/commit/fa46de7feeeb8a01dc471581a0258252ce4f2db6
Conflict:NA
---
Notes (below the "---" line; ignored by git am and patch):
* When strict kex has been negotiated, the packetizer's inbound and
  outbound sequence numbers are reset to 0 once NEWKEYS takes effect
  (Packetizer.reset_seqno_in / reset_seqno_out).
* The "KEXINIT must carry seqno 0" check now only applies while the initial
  key exchange is in progress, so a strict-mode rekey is not rejected.
* Uses Transport.initial_kex_done, which this patch does not define --
  presumably provided by an earlier strict-kex backport; confirm before
  applying on its own.

 paramiko/packet.py      |  6 ++++++
 paramiko/transport.py   | 22 ++++++++++++++++++++--
 tests/test_transport.py | 25 +++++++++++++++++++++++--
 3 files changed, 49 insertions(+), 4 deletions(-)

diff --git a/paramiko/packet.py b/paramiko/packet.py
index 1266316..1fc06d9 100644
--- a/paramiko/packet.py
+++ b/paramiko/packet.py
@@ -130,6 +130,12 @@ class Packetizer(object):
     def closed(self):
         return self.__closed
 
+    def reset_seqno_out(self):
+        self.__sequence_number_out = 0
+
+    def reset_seqno_in(self):
+        self.__sequence_number_in = 0
+
     def set_log(self, log):
         """
         Set the Python log object to use for logging.
diff --git a/paramiko/transport.py b/paramiko/transport.py
index 83b1c81..0c68668 100644
--- a/paramiko/transport.py
+++ b/paramiko/transport.py
@@ -2469,9 +2469,13 @@ class Transport(threading.Thread, ClosingContextManager):
 
         # CVE mitigation: expect zeroed-out seqno anytime we are performing kex
         # init phase, if strict mode was negotiated.
-        if self.agreed_on_strict_kex and m.seqno != 0:
+        if (
+            self.agreed_on_strict_kex
+            and not self.initial_kex_done
+            and m.seqno != 0
+        ):
             raise MessageOrderError(
-                f"Got nonzero seqno ({m.seqno}) during strict KEXINIT!"
+                "In strict-kex mode, but KEXINIT was not the first packet!"
             )
 
         # as a server, we pick the first item in the client's list that we
@@ -2670,6 +2674,13 @@ class Transport(threading.Thread, ClosingContextManager):
         ):
             self._log(DEBUG, "Switching on inbound compression ...")
             self.packetizer.set_inbound_compressor(compress_in())
+        # Reset inbound sequence number if strict mode.
+        if self.agreed_on_strict_kex:
+            self._log(
+                DEBUG,
+                f"Resetting inbound seqno after NEWKEYS due to strict mode",
+            )
+            self.packetizer.reset_seqno_in()
 
     def _activate_outbound(self):
         """switch on newly negotiated encryption parameters for
@@ -2677,6 +2688,13 @@ class Transport(threading.Thread, ClosingContextManager):
         m = Message()
         m.add_byte(cMSG_NEWKEYS)
         self._send_message(m)
+        # Reset outbound sequence number if strict mode.
+        if self.agreed_on_strict_kex:
+            self._log(
+                DEBUG,
+                f"Resetting outbound sequence number after NEWKEYS due to strict mode",
+            )
+            self.packetizer.reset_seqno_out()
         block_size = self._cipher_info[self.local_cipher]["block-size"]
         if self.server_mode:
             IV_out = self._compute_key("B", block_size)
diff --git a/tests/test_transport.py b/tests/test_transport.py
index 7440e88..9c3e8f5 100644
--- a/tests/test_transport.py
+++ b/tests/test_transport.py
@@ -1548,5 +1548,26 @@ class TestStrictKex:
             ):
                 pass  # kexinit happens at connect...
 
-    def test_sequence_numbers_reset_on_newkeys(self):
-        skip()
+    def test_sequence_numbers_reset_on_newkeys_when_strict(self):
+        with server(defer=True) as (tc, ts):
+            # When in strict mode, these should all be zero or close to it
+            # (post-kexinit, pre-auth).
+            # Server->client will be 1 (EXT_INFO got sent after NEWKEYS)
+            assert tc.packetizer._Packetizer__sequence_number_in == 1
+            assert ts.packetizer._Packetizer__sequence_number_out == 1
+            # Client->server will be 0
+            assert tc.packetizer._Packetizer__sequence_number_out == 0
+            assert ts.packetizer._Packetizer__sequence_number_in == 0
+
+    def test_sequence_numbers_not_reset_on_newkeys_when_not_strict(self):
+        with server(defer=True, client_init=dict(strict_kex=False)) as (
+            tc,
+            ts,
+        ):
+            # When not in strict mode, these will all be ~3-4 or so
+            # (post-kexinit, pre-auth). Not encoding exact values as it will
+            # change anytime we mess with the test harness...
+            assert tc.packetizer._Packetizer__sequence_number_in != 0
+            assert tc.packetizer._Packetizer__sequence_number_out != 0
+            assert ts.packetizer._Packetizer__sequence_number_in != 0
+            assert ts.packetizer._Packetizer__sequence_number_out != 0
-- 
2.33.0