Browse Source

SMBServer: Each connection now use a dedicaded thread for send operations

Tal Aloni 8 years ago
parent
commit
84affda0ff

+ 1 - 0
SMBLibrary/Server/ConnectionManager.cs

@@ -38,6 +38,7 @@ namespace SMBLibrary.Server
 
         public void ReleaseConnection(ConnectionState connection)
         {
+            connection.SendQueue.Stop();
             SocketUtils.ReleaseSocket(connection.ClientSocket);
             RemoveConnection(connection);
         }

+ 3 - 0
SMBLibrary/Server/ConnectionState/ConnectionState.cs

@@ -20,6 +20,7 @@ namespace SMBLibrary.Server
         public Socket ClientSocket;
         public IPEndPoint ClientEndPoint;
         public NBTConnectionReceiveBuffer ReceiveBuffer;
+        public BlockingQueue<SessionPacket> SendQueue;
         protected LogDelegate LogToServerHandler;
         public SMBDialect Dialect;
         public object AuthenticationContext;
@@ -27,6 +28,7 @@ namespace SMBLibrary.Server
         public ConnectionState(LogDelegate logToServerHandler)
         {
             ReceiveBuffer = new NBTConnectionReceiveBuffer();
+            SendQueue = new BlockingQueue<SessionPacket>();
             LogToServerHandler = logToServerHandler;
             Dialect = SMBDialect.NotSet;
         }
@@ -36,6 +38,7 @@ namespace SMBLibrary.Server
             ClientSocket = state.ClientSocket;
             ClientEndPoint = state.ClientEndPoint;
             ReceiveBuffer = state.ReceiveBuffer;
+            SendQueue = state.SendQueue;
             LogToServerHandler = state.LogToServerHandler;
             Dialect = state.Dialect;
         }

+ 5 - 5
SMBLibrary/Server/SMBServer.SMB1.cs

@@ -49,7 +49,7 @@ namespace SMBLibrary.Server
                             index--;
                         }
                     }
-                    TrySendMessage(state, reply);
+                    EnqueueMessage(state, reply);
                 }
             }
 
@@ -58,7 +58,7 @@ namespace SMBLibrary.Server
                 SMB1Message reply = new SMB1Message();
                 reply.Header = header;
                 reply.Commands.Add(response);
-                TrySendMessage(state, reply);
+                EnqueueMessage(state, reply);
             }
         }
 
@@ -302,12 +302,12 @@ namespace SMBLibrary.Server
             return new ErrorResponse(command.CommandName);
         }
 
-        private static void TrySendMessage(ConnectionState state, SMB1Message response)
+        private static void EnqueueMessage(ConnectionState state, SMB1Message response)
         {
             SessionMessagePacket packet = new SessionMessagePacket();
             packet.Trailer = response.GetBytes();
-            TrySendPacket(state, packet);
-            state.LogToServer(Severity.Verbose, "SMB1 message sent: {0} responses, First response: {1}, Packet length: {2}", response.Commands.Count, response.Commands[0].CommandName.ToString(), packet.Length);
+            state.SendQueue.Enqueue(packet);
+            state.LogToServer(Severity.Verbose, "SMB1 message queued: {0} responses, First response: {1}, Packet length: {2}", response.Commands.Count, response.Commands[0].CommandName.ToString(), packet.Length);
         }
 
         private static void PrepareResponseHeader(SMB1Header responseHeader, SMB1Header requestHeader)

+ 7 - 7
SMBLibrary/Server/SMBServer.SMB2.cs

@@ -60,7 +60,7 @@ namespace SMBLibrary.Server
             }
             if (responseChain.Count > 0)
             {
-                TrySendResponseChain(state, responseChain);
+                EnqueueResponseChain(state, responseChain);
             }
         }
 
@@ -223,15 +223,15 @@ namespace SMBLibrary.Server
             return new ErrorResponse(command.CommandName, NTStatus.STATUS_NOT_SUPPORTED);
         }
 
-        private static void TrySendResponse(ConnectionState state, SMB2Command response)
+        private static void EnqueueResponse(ConnectionState state, SMB2Command response)
         {
             SessionMessagePacket packet = new SessionMessagePacket();
             packet.Trailer = response.GetBytes();
-            TrySendPacket(state, packet);
-            state.LogToServer(Severity.Verbose, "SMB2 response sent: {0}, Packet length: {1}", response.CommandName.ToString(), packet.Length);
+            state.SendQueue.Enqueue(packet);
+            state.LogToServer(Severity.Verbose, "SMB2 response queued: {0}, Packet length: {1}", response.CommandName.ToString(), packet.Length);
         }
 
-        private static void TrySendResponseChain(ConnectionState state, List<SMB2Command> responseChain)
+        private static void EnqueueResponseChain(ConnectionState state, List<SMB2Command> responseChain)
         {
             byte[] sessionKey = null;
             if (state is SMB2ConnectionState)
@@ -252,8 +252,8 @@ namespace SMBLibrary.Server
 
             SessionMessagePacket packet = new SessionMessagePacket();
             packet.Trailer = SMB2Command.GetCommandChainBytes(responseChain, sessionKey);
-            TrySendPacket(state, packet);
-            state.LogToServer(Severity.Verbose, "SMB2 response chain sent: Response count: {0}, First response: {1}, Packet length: {2}", responseChain.Count, responseChain[0].CommandName.ToString(), packet.Length);
+            state.SendQueue.Enqueue(packet);
+            state.LogToServer(Severity.Verbose, "SMB2 response chain queued: Response count: {0}, First response: {1}, Packet length: {2}", responseChain.Count, responseChain[0].CommandName.ToString(), packet.Length);
         }
 
         private static void UpdateSMB2Header(SMB2Command response, SMB2Command request)

+ 36 - 12
SMBLibrary/Server/SMBServer.cs

@@ -8,6 +8,7 @@ using System;
 using System.Collections.Generic;
 using System.Net;
 using System.Net.Sockets;
+using System.Threading;
 using SMBLibrary.Authentication.GSSAPI;
 using SMBLibrary.NetBios;
 using SMBLibrary.Services;
@@ -117,6 +118,13 @@ namespace SMBLibrary.Server
             state.ClientSocket = clientSocket;
             state.ClientEndPoint = clientSocket.RemoteEndPoint as IPEndPoint;
             state.LogToServer(Severity.Verbose, "New connection request");
+            Thread senderThread = new Thread(delegate()
+            {
+                ProcessSendQueue(state);
+            });
+            senderThread.IsBackground = true;
+            senderThread.Start();
+
             try
             {
                 // Direct TCP transport packet is actually an NBT Session Message Packet,
@@ -219,7 +227,7 @@ namespace SMBLibrary.Server
             if (packet is SessionRequestPacket && m_transport == SMBTransportType.NetBiosOverTCP)
             {
                 PositiveSessionResponsePacket response = new PositiveSessionResponsePacket();
-                TrySendPacket(state, response);
+                state.SendQueue.Enqueue(response);
             }
             else if (packet is SessionKeepAlivePacket && m_transport == SMBTransportType.NetBiosOverTCP)
             {
@@ -265,7 +273,7 @@ namespace SMBLibrary.Server
                                 state = new SMB2ConnectionState(state, AllocatePersistentFileID);
                                 m_connectionManager.AddConnection(state);
                             }
-                            TrySendResponse(state, response);
+                            EnqueueResponse(state, response);
                             return;
                         }
                     }
@@ -319,18 +327,34 @@ namespace SMBLibrary.Server
             }
         }
 
-        private static void TrySendPacket(ConnectionState state, SessionPacket response)
+        private void ProcessSendQueue(ConnectionState state)
         {
-            Socket clientSocket = state.ClientSocket;
-            try
-            {
-                clientSocket.Send(response.GetBytes());
-            }
-            catch (SocketException)
-            {
-            }
-            catch (ObjectDisposedException)
+            while (true)
             {
+                Log(Severity.Trace, "Entering ProcessSendQueue");
+                SessionPacket response;
+                bool stopped = !state.SendQueue.TryDequeue(out response);
+                if (stopped)
+                {
+                    return;
+                }
+                Socket clientSocket = state.ClientSocket;
+                try
+                {
+                    clientSocket.Send(response.GetBytes());
+                }
+                catch (SocketException ex)
+                {
+                    Log(Severity.Debug, "[{0}] Failed to send packet. SocketException: {1}", state.ConnectionIdentifier, ex.Message);
+                    Log(Severity.Trace, "Leaving ProcessSendQueue");
+                    return;
+                }
+                catch (ObjectDisposedException)
+                {
+                    Log(Severity.Debug, "[{0}] Failed to send packet. ObjectDisposedException.", state.ConnectionIdentifier);
+                    Log(Severity.Trace, "Leaving ProcessSendQueue");
+                    return;
+                }
             }
         }