Преглед на файлове

Improved connection buffer implementation

Tal Aloni преди 8 години
родител
ревизия
19cb25c463
променени са 4 файла, в които са добавени 155 реда и са изтрити 85 реда
  1. 1 0
      SMBLibrary/SMBLibrary.csproj
  2. 131 0
      SMBLibrary/Server/SMBConnectionReceiveBuffer.cs
  3. 20 81
      SMBLibrary/Server/SMBServer.cs
  4. 3 4
      SMBLibrary/Server/StateObject.cs

+ 1 - 0
SMBLibrary/SMBLibrary.csproj

@@ -116,6 +116,7 @@
     <Compile Include="Server\ResponseHelpers\ServerResponseHelper.cs" />
     <Compile Include="Server\FileSystemShare.cs" />
     <Compile Include="Server\ShareCollection.cs" />
+    <Compile Include="Server\SMBConnectionReceiveBuffer.cs" />
     <Compile Include="Server\SMBServer.cs" />
     <Compile Include="Server\StateObject.cs" />
     <Compile Include="Server\User.cs" />

+ 131 - 0
SMBLibrary/Server/SMBConnectionReceiveBuffer.cs

@@ -0,0 +1,131 @@
+/* Copyright (C) 2014-2017 Tal Aloni <tal.aloni.il@gmail.com>. All rights reserved.
+ * 
+ * You can redistribute this program and/or modify it under the terms of
+ * the GNU Lesser Public License as published by the Free Software Foundation,
+ * either version 3 of the License, or (at your option) any later version.
+ */
+using System;
+using System.Collections.Generic;
+using System.Text;
+using SMBLibrary.NetBios;
+using Utilities;
+
+namespace SMBLibrary.Server
+{
+    public class SMBConnectionReceiveBuffer
+    {
+        private byte[] m_buffer;
+        private int m_readOffset = 0;
+        private int m_bytesInBuffer = 0;
+        private int? m_packetLength;
+
+        /// <param name="bufferLength">Must be large enough to hold the largest possible packet</param>
+        public SMBConnectionReceiveBuffer(int bufferLength)
+        {
+            m_buffer = new byte[bufferLength];
+        }
+
+        public void SetNumberOfBytesReceived(int numberOfBytesReceived)
+        {
+            m_bytesInBuffer += numberOfBytesReceived;
+        }
+
+        public bool HasCompletePacket()
+        {
+            if (m_bytesInBuffer >= 4)
+            {
+                if (!m_packetLength.HasValue)
+                {
+                    // The packet is either Direct TCP transport packet (which is an NBT Session Message
+                    // Packet) or an NBT packet.
+                    byte flags = ByteReader.ReadByte(m_buffer, m_readOffset + 1);
+                    int trailerLength = (flags & 0x01) << 16 | BigEndianConverter.ToUInt16(m_buffer, m_readOffset + 2);
+                    m_packetLength = 4 + trailerLength;
+                }
+                return m_bytesInBuffer >= m_packetLength.Value;
+            }
+            return false;
+        }
+
+        /// <summary>
+        /// HasCompletePacket must be called and return true before calling DequeuePacket
+        /// </summary>
+        /// <exception cref="System.IO.InvalidDataException"></exception>
+        public SessionPacket DequeuePacket()
+        {
+            SessionPacket packet;
+            try
+            {
+                packet = SessionPacket.GetSessionPacket(m_buffer, m_readOffset);
+            }
+            catch (IndexOutOfRangeException ex)
+            {
+                throw new System.IO.InvalidDataException("Invalid Packet", ex);
+            }
+            RemovePacketBytes();
+            return packet;
+        }
+
+        /// <summary>
+        /// HasCompletePDU must be called and return true before calling DequeuePDUBytes
+        /// </summary>
+        public byte[] DequeuePacketBytes()
+        {
+            byte[] packetBytes = ByteReader.ReadBytes(m_buffer, m_readOffset, m_packetLength.Value);
+            RemovePacketBytes();
+            return packetBytes;
+        }
+
+        private void RemovePacketBytes()
+        {
+            m_bytesInBuffer -= m_packetLength.Value;
+            if (m_bytesInBuffer == 0)
+            {
+                m_readOffset = 0;
+                m_packetLength = null;
+            }
+            else
+            {
+                m_readOffset += m_packetLength.Value;
+                m_packetLength = null;
+                if (!HasCompletePacket())
+                {
+                    Array.Copy(m_buffer, m_readOffset, m_buffer, 0, m_bytesInBuffer);
+                    m_readOffset = 0;
+                }
+            }
+        }
+
+        public byte[] Buffer
+        {
+            get
+            {
+                return m_buffer;
+            }
+        }
+
+        public int WriteOffset
+        {
+            get
+            {
+                return m_readOffset + m_bytesInBuffer;
+            }
+        }
+
+        public int BytesInBuffer
+        {
+            get
+            {
+                return m_bytesInBuffer;
+            }
+        }
+
+        public int AvailableLength
+        {
+            get
+            {
+                return m_buffer.Length - (m_readOffset + m_bytesInBuffer);
+            }
+        }
+    }
+}

+ 20 - 81
SMBLibrary/Server/SMBServer.cs

@@ -95,13 +95,12 @@ namespace SMBLibrary.Server
             }
 
             StateObject state = new StateObject();
-            state.ReceiveBuffer = new byte[StateObject.ReceiveBufferSize];
             // Disable the Nagle Algorithm for this tcp socket:
             clientSocket.NoDelay = true;
             state.ClientSocket = clientSocket;
             try
             {
-                clientSocket.BeginReceive(state.ReceiveBuffer, 0, StateObject.ReceiveBufferSize, 0, ReceiveCallback, state);
+                clientSocket.BeginReceive(state.ReceiveBuffer.Buffer, state.ReceiveBuffer.WriteOffset, state.ReceiveBuffer.AvailableLength, 0, ReceiveCallback, state);
             }
             catch (ObjectDisposedException)
             {
@@ -123,13 +122,10 @@ namespace SMBLibrary.Server
                 return;
             }
 
-            byte[] receiveBuffer = state.ReceiveBuffer;
-
-            int bytesReceived;
-
+            int numberOfBytesReceived;
             try
             {
-                bytesReceived = clientSocket.EndReceive(result);
+                numberOfBytesReceived = clientSocket.EndReceive(result);
             }
             catch (ObjectDisposedException)
             {
@@ -140,7 +136,7 @@ namespace SMBLibrary.Server
                 return;
             }
 
-            if (bytesReceived == 0)
+            if (numberOfBytesReceived == 0)
             {
                 // The other side has closed the connection
                 System.Diagnostics.Debug.Print("[{0}] The other side closed the connection", DateTime.Now.ToString("HH:mm:ss:ffff"));
@@ -148,16 +144,15 @@ namespace SMBLibrary.Server
                 return;
             }
 
-            byte[] currentBuffer = new byte[bytesReceived];
-            Array.Copy(receiveBuffer, currentBuffer, bytesReceived);
-
-            ProcessCurrentBuffer(currentBuffer, state);
+            SMBConnectionReceiveBuffer receiveBuffer = state.ReceiveBuffer;
+            receiveBuffer.SetNumberOfBytesReceived(numberOfBytesReceived);
+            ProcessConnectionBuffer(state);
 
             if (clientSocket.Connected)
             {
                 try
                 {
-                    clientSocket.BeginReceive(state.ReceiveBuffer, 0, StateObject.ReceiveBufferSize, 0, ReceiveCallback, state);
+                    clientSocket.BeginReceive(state.ReceiveBuffer.Buffer, state.ReceiveBuffer.WriteOffset, state.ReceiveBuffer.AvailableLength, 0, ReceiveCallback, state);
                 }
                 catch (ObjectDisposedException)
                 {
@@ -168,88 +163,32 @@ namespace SMBLibrary.Server
             }
         }
 
-        public void ProcessCurrentBuffer(byte[] currentBuffer, StateObject state)
+        public void ProcessConnectionBuffer(StateObject state)
         {
             Socket clientSocket = state.ClientSocket;
 
-            if (state.ConnectionBuffer.Length == 0)
-            {
-                state.ConnectionBuffer = currentBuffer;
-            }
-            else
+            SMBConnectionReceiveBuffer receiveBuffer = state.ReceiveBuffer;
+            while (receiveBuffer.HasCompletePacket())
             {
-                byte[] oldConnectionBuffer = state.ConnectionBuffer;
-                state.ConnectionBuffer = new byte[oldConnectionBuffer.Length + currentBuffer.Length];
-                Array.Copy(oldConnectionBuffer, state.ConnectionBuffer, oldConnectionBuffer.Length);
-                Array.Copy(currentBuffer, 0, state.ConnectionBuffer, oldConnectionBuffer.Length, currentBuffer.Length);
-            }
-
-            // we now have all SMB message bytes received so far in state.ConnectionBuffer
-            int bytesLeftInBuffer = state.ConnectionBuffer.Length;
-
-
-            while (bytesLeftInBuffer >= 4)
-            {
-                // The packet is either Direct TCP transport packet (which is an NBT Session Message
-                // Packet) or an NBT packet.
-                int bufferOffset = state.ConnectionBuffer.Length - bytesLeftInBuffer;
-                byte flags = ByteReader.ReadByte(state.ConnectionBuffer, bufferOffset + 1);
-                int trailerLength = (flags & 0x01) << 16 | BigEndianConverter.ToUInt16(state.ConnectionBuffer, bufferOffset + 2);
-                int packetLength = 4 + trailerLength;
-
-                if (flags > 0x01)
+                SessionPacket packet = null;
+                try
                 {
-                    System.Diagnostics.Debug.Print("[{0}] Invalid NBT flags", DateTime.Now.ToString("HH:mm:ss:ffff"));
-                    state.ClientSocket.Close();
-                    return;
+                    packet = receiveBuffer.DequeuePacket();
                 }
-
-                if (packetLength > bytesLeftInBuffer)
+                catch (Exception)
                 {
-                    break;
+                    state.ClientSocket.Close();
                 }
-                else
+
+                if (packet != null)
                 {
-                    byte[] packetBytes = new byte[packetLength];
-                    Array.Copy(state.ConnectionBuffer, bufferOffset, packetBytes, 0, packetLength);
-                    ProcessPacket(packetBytes, state);
-                    bytesLeftInBuffer -= packetLength;
-                    if (!clientSocket.Connected)
-                    {
-                        // Do not continue to process the buffer if the other side closed the connection
-                        return;
-                    }
+                    ProcessPacket(packet, state);
                 }
             }
-
-            if (bytesLeftInBuffer > 0)
-            {
-                byte[] newReceiveBuffer = new byte[bytesLeftInBuffer];
-                Array.Copy(state.ConnectionBuffer, state.ConnectionBuffer.Length - bytesLeftInBuffer, newReceiveBuffer, 0, bytesLeftInBuffer);
-                state.ConnectionBuffer = newReceiveBuffer;
-            }
-            else
-            {
-                state.ConnectionBuffer = new byte[0];
-            }
         }
 
-        public void ProcessPacket(byte[] packetBytes, StateObject state)
+        public void ProcessPacket(SessionPacket packet, StateObject state)
         {
-            SessionPacket packet = null;
-#if DEBUG
-            packet = SessionPacket.GetSessionPacket(packetBytes, 0);
-#else
-            try
-            {
-                packet = SessionPacket.GetSessionPacket(packetBytes, 0);
-            }
-            catch (Exception)
-            {
-                state.ClientSocket.Close();
-                return;
-            }
-#endif
             if (packet is SessionRequestPacket && m_transport == SMBTransportType.NetBiosOverTCP)
             {
                 PositiveSessionResponsePacket response = new PositiveSessionResponsePacket();

+ 3 - 4
SMBLibrary/Server/StateObject.cs

@@ -1,4 +1,4 @@
-/* Copyright (C) 2014-2016 Tal Aloni <tal.aloni.il@gmail.com>. All rights reserved.
+/* Copyright (C) 2014-2017 Tal Aloni <tal.aloni.il@gmail.com>. All rights reserved.
  * 
  * You can redistribute this program and/or modify it under the terms of
  * the GNU Lesser Public License as published by the Free Software Foundation,
@@ -16,9 +16,8 @@ namespace SMBLibrary.Server
     public class StateObject
     {
         public Socket ClientSocket = null;
-        public const int ReceiveBufferSize = 65536;
-        public byte[] ReceiveBuffer = new byte[ReceiveBufferSize]; // immediate receive buffer
-        public byte[] ConnectionBuffer = new byte[0]; // we append the receive buffer here until we have a complete Message
+        public const int ReceiveBufferSize = 131075; // Largest NBT Session Packet
+        public SMBConnectionReceiveBuffer ReceiveBuffer = new SMBConnectionReceiveBuffer(ReceiveBufferSize);
 
         public int MaxBufferSize;
         public bool LargeRead;