Browse Source

SMBServer: Improved Named Pipe implementation

Tal Aloni 7 years ago
parent
commit
4aba3f5244

+ 10 - 1
SMBLibrary/NTFileStore/NamedPipeStore.cs

@@ -110,12 +110,21 @@ namespace SMBLibrary
                 {
                     return writeStatus;
                 }
+                int messageLength = ((RPCPipeStream)((FileHandle)handle).Stream).MessageLength;
                 NTStatus readStatus = ReadFile(out output, handle, 0, maxOutputLength);
                 if (readStatus != NTStatus.STATUS_SUCCESS)
                 {
                     return readStatus;
                 }
-                return NTStatus.STATUS_SUCCESS;
+
+                if (output.Length < messageLength)
+                {
+                    return NTStatus.STATUS_BUFFER_OVERFLOW;
+                }
+                else
+                {
+                    return NTStatus.STATUS_SUCCESS;
+                }
             }
 
             return NTStatus.STATUS_NOT_SUPPORTED;

+ 65 - 13
SMBLibrary/Services/RPCPipeStream.cs

@@ -14,17 +14,30 @@ namespace SMBLibrary.Services
     public class RPCPipeStream : Stream
     {
         private RemoteService m_service;
-        private MemoryStream m_outputStream;
+        private List<MemoryStream> m_outputStreams; // A stream for each message in order to support message mode named pipe
+        private int? m_maxTransmitFragmentSize;
 
         public RPCPipeStream(RemoteService service)
         {
             m_service = service;
-            m_outputStream = new MemoryStream();
+            m_outputStreams = new List<MemoryStream>();
         }
 
         public override int Read(byte[] buffer, int offset, int count)
         {
-            return m_outputStream.Read(buffer, offset, count);
+            if (m_outputStreams.Count > 0)
+            {
+                int result = m_outputStreams[0].Read(buffer, offset, count);
+                if (m_outputStreams[0].Position == m_outputStreams[0].Length)
+                {
+                    m_outputStreams.RemoveAt(0);
+                }
+                return result;
+            }
+            else
+            {
+                return 0;
+            }
         }
 
         public override void Write(byte[] buffer, int offset, int count)
@@ -33,20 +46,42 @@ namespace SMBLibrary.Services
             do
             {
                 RPCPDU rpcRequest = RPCPDU.GetPDU(buffer, offset);
+                ProcessRPCRequest(rpcRequest);
                 lengthOfPDUs += rpcRequest.FragmentLength;
-                RPCPDU rpcReply = RemoteServiceHelper.GetRPCReply(rpcRequest, m_service);
-                byte[] replyData = rpcReply.GetBytes();
-                Append(replyData);
             }
             while (lengthOfPDUs < count);
         }
 
+        private void ProcessRPCRequest(RPCPDU rpcRequest)
+        {
+            if (rpcRequest is BindPDU)
+            {
+                BindAckPDU bindAckPDU = RemoteServiceHelper.GetRPCBindResponse((BindPDU)rpcRequest, m_service);
+                m_maxTransmitFragmentSize = bindAckPDU.MaxTransmitFragmentSize;
+                Append(bindAckPDU.GetBytes());
+            }
+            else if (rpcRequest is RequestPDU)
+            {
+                // if BindPDU was not received, we ignore any subsequent RPC packets
+                if (m_maxTransmitFragmentSize.HasValue)
+                {
+                    List<ResponsePDU> responsePDUs = RemoteServiceHelper.GetRPCResponse((RequestPDU)rpcRequest, m_service, m_maxTransmitFragmentSize.Value);
+                    foreach (ResponsePDU responsePDU in responsePDUs)
+                    {
+                        Append(responsePDU.GetBytes());
+                    }
+                }
+            }
+            else
+            {
+                throw new NotImplementedException("Unsupported RPC Packet Type");
+            }
+        }
+
         private void Append(byte[] buffer)
         {
-            long position = m_outputStream.Position;
-            m_outputStream.Position = m_outputStream.Length;
-            m_outputStream.Write(buffer, 0, buffer.Length);
-            m_outputStream.Seek(position, SeekOrigin.Begin);
+            MemoryStream stream = new MemoryStream(buffer);
+            m_outputStreams.Add(stream);
         }
 
         public override void Flush()
@@ -55,7 +90,6 @@ namespace SMBLibrary.Services
 
         public override void Close()
         {
-            m_outputStream.Close();
         }
 
         public override long Seek(long offset, SeekOrigin origin)
@@ -80,7 +114,7 @@ namespace SMBLibrary.Services
         {
             get
             {
-                return m_outputStream.CanRead;
+                return true;
             }
         }
 
@@ -88,7 +122,7 @@ namespace SMBLibrary.Services
         {
             get
             {
-                return m_outputStream.CanWrite;
+                return true;
             }
         }
 
@@ -112,5 +146,23 @@ namespace SMBLibrary.Services
                 throw new NotSupportedException();
             }
         }
+
+        /// <summary>
+        /// The length of the first message available in the pipe
+        /// </summary>
+        public int MessageLength
+        {
+            get
+            {
+                if (m_outputStreams.Count > 0)
+                {
+                    return (int)m_outputStreams[0].Length;
+                }
+                else
+                {
+                    return 0;
+                }
+            }
+        }
     }
 }

+ 33 - 33
SMBLibrary/Services/RemoteServiceHelper.cs

@@ -24,26 +24,12 @@ namespace SMBLibrary.Services
         private static readonly Guid BindTimeFeatureIdentifier3 = new Guid("6CB71C2C-9812-4540-0300-000000000000");
         private static uint m_associationGroupID = 1;
 
-        public static RPCPDU GetRPCReply(RPCPDU pdu, RemoteService service)
-        {
-            if (pdu is BindPDU)
-            {
-                return GetRPCBindResponse((BindPDU)pdu, service);
-            }
-            else if (pdu is RequestPDU)
-            {
-                return GetRPCResponse((RequestPDU)pdu, service);
-            }
-            else
-            {
-                throw new NotImplementedException();
-            }
-        }
-
-        private static BindAckPDU GetRPCBindResponse(BindPDU bindPDU, RemoteService service)
+        public static BindAckPDU GetRPCBindResponse(BindPDU bindPDU, RemoteService service)
         {
             BindAckPDU bindAckPDU = new BindAckPDU();
-            PrepareReply(bindAckPDU, bindPDU);
+            bindAckPDU.Flags = PacketFlags.FirstFragment | PacketFlags.LastFragment;
+            bindAckPDU.DataRepresentation = bindPDU.DataRepresentation;
+            bindAckPDU.CallID = bindPDU.CallID;
             // See DCE 1.1: Remote Procedure Call - 12.6.3.6
             // The client should set the assoc_group_id field either to 0 (zero), to indicate a new association group,
             // or to the known value. When the server receives a value of 0, this indicates that the client
@@ -62,8 +48,8 @@ namespace SMBLibrary.Services
                 bindAckPDU.AssociationGroupID = bindPDU.AssociationGroupID;
             }
             bindAckPDU.SecondaryAddress = @"\PIPE\" + service.PipeName;
-            bindAckPDU.MaxReceiveFragmentSize = bindPDU.MaxReceiveFragmentSize;
-            bindAckPDU.MaxTransmitFragmentSize = bindPDU.MaxTransmitFragmentSize;
+            bindAckPDU.MaxTransmitFragmentSize = bindPDU.MaxReceiveFragmentSize;
+            bindAckPDU.MaxReceiveFragmentSize = bindPDU.MaxTransmitFragmentSize;
             foreach (ContextElement element in bindPDU.ContextList)
             {
                 ResultElement resultElement = new ResultElement();
@@ -117,20 +103,34 @@ namespace SMBLibrary.Services
             return -1;
         }
 
-        private static ResponsePDU GetRPCResponse(RequestPDU requestPDU, RemoteService service)
+        public static List<ResponsePDU> GetRPCResponse(RequestPDU requestPDU, RemoteService service, int maxTransmitFragmentSize)
         {
-            ResponsePDU responsePDU = new ResponsePDU();
-            PrepareReply(responsePDU, requestPDU);
-            responsePDU.Data = service.GetResponseBytes(requestPDU.OpNum, requestPDU.Data);
-            responsePDU.AllocationHint = (uint)responsePDU.Data.Length;
-            return responsePDU;
-        }
-
-        private static void PrepareReply(RPCPDU reply, RPCPDU request)
-        {
-            reply.DataRepresentation = request.DataRepresentation;
-            reply.CallID = request.CallID;
-            reply.Flags = PacketFlags.FirstFragment | PacketFlags.LastFragment;
+            byte[] responseBytes = service.GetResponseBytes(requestPDU.OpNum, requestPDU.Data);
+            int offset = 0;
+            List<ResponsePDU> result = new List<ResponsePDU>();
+            int maxPDUDataLength = maxTransmitFragmentSize - RPCPDU.CommonFieldsLength - ResponsePDU.ResponseFieldsLength;
+            do
+            {
+                ResponsePDU responsePDU = new ResponsePDU();
+                int pduDataLength = Math.Min(responseBytes.Length - offset, maxPDUDataLength);
+                responsePDU.DataRepresentation = requestPDU.DataRepresentation;
+                responsePDU.CallID = requestPDU.CallID;
+                responsePDU.AllocationHint = (uint)(responseBytes.Length - offset);
+                responsePDU.Data = ByteReader.ReadBytes(responseBytes, offset, pduDataLength);
+                if (offset == 0)
+                {
+                    responsePDU.Flags |= PacketFlags.FirstFragment;
+                }
+                if (offset + pduDataLength == responseBytes.Length)
+                {
+                    responsePDU.Flags |= PacketFlags.LastFragment;
+                }
+                result.Add(responsePDU);
+                offset += pduDataLength;
+            }
+            while (offset < responseBytes.Length);
+            
+            return result;
         }
     }
 }