Browse Source

Refactored ReadWriteResponseHelper

Tal Aloni 8 years ago
parent
commit
d42240ffa1
1 changed files with 76 additions and 73 deletions
  1. 76 73
      SMBLibrary/Server/SMB1/ReadWriteResponseHelper.cs

+ 76 - 73
SMBLibrary/Server/SMB1/ReadWriteResponseHelper.cs

@@ -19,10 +19,17 @@ namespace SMBLibrary.Server.SMB1
     {
         internal static SMB1Command GetReadResponse(SMB1Header header, ReadRequest request, ISMBShare share, SMB1ConnectionState state)
         {
-            byte[] data = PerformRead(header, share, request.FID, request.ReadOffsetInBytes, request.CountOfBytesToRead, state);
+            OpenedFileObject openedFile = state.GetOpenedFileObject(request.FID);
+            if (openedFile == null)
+            {
+                header.Status = NTStatus.STATUS_INVALID_HANDLE;
+                return null;
+            }
+            byte[] data;
+            header.Status = ReadFile(out data, openedFile, request.ReadOffsetInBytes, request.CountOfBytesToRead, state);
             if (header.Status != NTStatus.STATUS_SUCCESS)
             {
-                return new ErrorResponse(CommandName.SMB_COM_READ);
+                return new ErrorResponse(request.CommandName);
             }
 
             ReadResponse response = new ReadResponse();
@@ -33,15 +40,22 @@ namespace SMBLibrary.Server.SMB1
 
         internal static SMB1Command GetReadResponse(SMB1Header header, ReadAndXRequest request, ISMBShare share, SMB1ConnectionState state)
         {
+            OpenedFileObject openedFile = state.GetOpenedFileObject(request.FID);
+            if (openedFile == null)
+            {
+                header.Status = NTStatus.STATUS_INVALID_HANDLE;
+                return null;
+            }
             uint maxCount = request.MaxCount;
             if ((share is FileSystemShare) && state.LargeRead)
             {
                 maxCount = request.MaxCountLarge;
             }
-            byte[] data = PerformRead(header, share, request.FID, request.Offset, maxCount, state);
+            byte[] data;
+            header.Status = ReadFile(out data, openedFile, (long)request.Offset, (int)maxCount, state);
             if (header.Status != NTStatus.STATUS_SUCCESS)
             {
-                return new ErrorResponse(CommandName.SMB_COM_READ_ANDX);
+                return new ErrorResponse(request.CommandName);
             }
 
             ReadAndXResponse response = new ReadAndXResponse();
@@ -54,47 +68,31 @@ namespace SMBLibrary.Server.SMB1
             return response;
         }
 
-        public static byte[] PerformRead(SMB1Header header, ISMBShare share, ushort FID, ulong offset, uint maxCount, SMB1ConnectionState state)
+        public static NTStatus ReadFile(out byte[] data, OpenedFileObject openedFile, long offset, int maxCount, ConnectionState state)
         {
-            if (offset > Int64.MaxValue || maxCount > Int32.MaxValue)
-            {
-                throw new NotImplementedException("Underlying filesystem does not support unsigned offset / read count");
-            }
-            return PerformRead(header, share, FID, (long)offset, (int)maxCount, state);
-        }
-
-        public static byte[] PerformRead(SMB1Header header, ISMBShare share, ushort FID, long offset, int maxCount, SMB1ConnectionState state)
-        {
-            OpenedFileObject openedFile = state.GetOpenedFileObject(FID);
-            if (openedFile == null)
-            {
-                header.Status = NTStatus.STATUS_INVALID_HANDLE;
-                return null;
-            }
+            data = null;
             string openedFilePath = openedFile.Path;
             Stream stream = openedFile.Stream;
-            if (share is NamedPipeShare)
+            if (stream is RPCPipeStream)
             {
-                byte[] data = new byte[maxCount];
+                data = new byte[maxCount];
                 int bytesRead = stream.Read(data, 0, maxCount);
                 if (bytesRead < maxCount)
                 {
                     // EOF, we must trim the response data array
                     data = ByteReader.ReadBytes(data, 0, bytesRead);
                 }
-                return data;
+                return NTStatus.STATUS_SUCCESS;
             }
-            else // FileSystemShare
+            else // File
             {
-
                 if (stream == null)
                 {
-                    header.Status = NTStatus.STATUS_ACCESS_DENIED;
-                    return null;
+                    state.LogToServer(Severity.Debug, "ReadFile: Cannot read '{0}', Invalid Operation.", openedFilePath);
+                    return NTStatus.STATUS_ACCESS_DENIED;
                 }
 
                 int bytesRead;
-                byte[] data;
                 try
                 {
                     stream.Seek(offset, SeekOrigin.Begin);
@@ -107,28 +105,24 @@ namespace SMBLibrary.Server.SMB1
                     if (errorCode == (ushort)Win32Error.ERROR_SHARING_VIOLATION)
                     {
                         // Returning STATUS_SHARING_VIOLATION is undocumented but apparently valid
-                        state.LogToServer(Severity.Debug, "ReadAndX: Cannot read '{0}'. Sharing Violation.", openedFilePath);
-                        header.Status = NTStatus.STATUS_SHARING_VIOLATION;
-                        return null;
+                        state.LogToServer(Severity.Debug, "ReadFile: Cannot read '{0}'. Sharing Violation.", openedFilePath);
+                        return NTStatus.STATUS_SHARING_VIOLATION;
                     }
                     else
                     {
-                        state.LogToServer(Severity.Debug, "ReadAndX: Cannot read '{0}'. Data Error.", openedFilePath);
-                        header.Status = NTStatus.STATUS_DATA_ERROR;
-                        return null;
+                        state.LogToServer(Severity.Debug, "ReadFile: Cannot read '{0}'. Data Error.", openedFilePath);
+                        return NTStatus.STATUS_DATA_ERROR;
                     }
                 }
                 catch (ArgumentOutOfRangeException)
                 {
-                    state.LogToServer(Severity.Debug, "ReadAndX: Cannot read '{0}'. Offset Out Of Range.", openedFilePath);
-                    header.Status = NTStatus.STATUS_DATA_ERROR;
-                    return null;
+                    state.LogToServer(Severity.Debug, "ReadFile: Cannot read '{0}'. Offset Out Of Range.", openedFilePath);
+                    return NTStatus.STATUS_DATA_ERROR;
                 }
                 catch (UnauthorizedAccessException)
                 {
-                    state.LogToServer(Severity.Debug, "ReadAndX: Cannot read '{0}', Access Denied.", openedFilePath);
-                    header.Status = NTStatus.STATUS_ACCESS_DENIED;
-                    return null;
+                    state.LogToServer(Severity.Debug, "ReadFile: Cannot read '{0}', Access Denied.", openedFilePath);
+                    return NTStatus.STATUS_ACCESS_DENIED;
                 }
 
                 if (bytesRead < maxCount)
@@ -136,33 +130,45 @@ namespace SMBLibrary.Server.SMB1
                     // EOF, we must trim the response data array
                     data = ByteReader.ReadBytes(data, 0, bytesRead);
                 }
-                return data;
+                return NTStatus.STATUS_SUCCESS;
             }
         }
 
         internal static SMB1Command GetWriteResponse(SMB1Header header, WriteRequest request, ISMBShare share, SMB1ConnectionState state)
         {
-            ushort bytesWritten = (ushort)PerformWrite(header, share, request.FID, request.WriteOffsetInBytes, request.Data, state);
+            OpenedFileObject openedFile = state.GetOpenedFileObject(request.FID);
+            if (openedFile == null)
+            {
+                header.Status = NTStatus.STATUS_INVALID_HANDLE;
+                return new ErrorResponse(request.CommandName);
+            }
+            int numberOfBytesWritten;
+            header.Status = WriteFile(out numberOfBytesWritten, openedFile, request.WriteOffsetInBytes, request.Data, state);
             if (header.Status != NTStatus.STATUS_SUCCESS)
             {
-                return new ErrorResponse(CommandName.SMB_COM_WRITE_ANDX);
+                return new ErrorResponse(request.CommandName);
             }
             WriteResponse response = new WriteResponse();
-            response.CountOfBytesWritten = bytesWritten;
-
+            response.CountOfBytesWritten = (ushort)numberOfBytesWritten;
             return response;
         }
 
         internal static SMB1Command GetWriteResponse(SMB1Header header, WriteAndXRequest request, ISMBShare share, SMB1ConnectionState state)
         {
-            uint bytesWritten = PerformWrite(header, share, request.FID, request.Offset, request.Data, state);
+            OpenedFileObject openedFile = state.GetOpenedFileObject(request.FID);
+            if (openedFile == null)
+            {
+                header.Status = NTStatus.STATUS_INVALID_HANDLE;
+                return new ErrorResponse(request.CommandName);
+            }
+            int numberOfBytesWritten;
+            header.Status = WriteFile(out numberOfBytesWritten, openedFile, (long)request.Offset, request.Data, state);
             if (header.Status != NTStatus.STATUS_SUCCESS)
             {
-                return new ErrorResponse(CommandName.SMB_COM_WRITE_ANDX);
+                return new ErrorResponse(request.CommandName);
             }
             WriteAndXResponse response = new WriteAndXResponse();
-            response.Count = bytesWritten;
-
+            response.Count = (uint)numberOfBytesWritten;
             if (share is FileSystemShare)
             {
                 // If the client wrote to a disk file, this field MUST be set to 0xFFFF.
@@ -171,65 +177,62 @@ namespace SMBLibrary.Server.SMB1
             return response;
         }
 
-        public static uint PerformWrite(SMB1Header header, ISMBShare share, ushort FID, ulong offset, byte[] data, SMB1ConnectionState state)
+        public static NTStatus WriteFile(out int numberOfBytesWritten, OpenedFileObject openedFile, long offset, byte[] data, ConnectionState state)
         {
-            OpenedFileObject openedFile = state.GetOpenedFileObject(FID);
-            if (openedFile == null)
-            {
-                header.Status = NTStatus.STATUS_INVALID_HANDLE;
-                return 0;
-            }
+            numberOfBytesWritten = 0;
             string openedFilePath = openedFile.Path;
             Stream stream = openedFile.Stream;
-            if (share is NamedPipeShare)
+            if (stream is RPCPipeStream)
             {
                 stream.Write(data, 0, data.Length);
-                return (uint)data.Length;
+                numberOfBytesWritten = data.Length;
+                return NTStatus.STATUS_SUCCESS;
             }
-            else // FileSystemShare
+            else // File
             {
                 if (stream == null)
                 {
-                     header.Status = NTStatus.STATUS_ACCESS_DENIED;
-                     return 0;
+                    state.LogToServer(Severity.Debug, "WriteFile: Cannot write '{0}'. Invalid Operation.", openedFilePath);
+                    return NTStatus.STATUS_ACCESS_DENIED;
                 }
 
                 try
                 {
-                    stream.Seek((long)offset, SeekOrigin.Begin);
+                    stream.Seek(offset, SeekOrigin.Begin);
                     stream.Write(data, 0, data.Length);
-                    return (uint)data.Length;
+                    numberOfBytesWritten = data.Length;
+                    return NTStatus.STATUS_SUCCESS;
                 }
                 catch (IOException ex)
                 {
                     ushort errorCode = IOExceptionHelper.GetWin32ErrorCode(ex);
                     if (errorCode == (ushort)Win32Error.ERROR_DISK_FULL)
                     {
-                        header.Status = NTStatus.STATUS_DISK_FULL;
-                        return 0;
+                        state.LogToServer(Severity.Debug, "WriteFile: Cannot write '{0}'. Disk Full.", openedFilePath);
+                        return NTStatus.STATUS_DISK_FULL;
                     }
                     else if (errorCode == (ushort)Win32Error.ERROR_SHARING_VIOLATION)
                     {
+                        state.LogToServer(Severity.Debug, "WriteFile: Cannot write '{0}'. Sharing Violation.", openedFilePath);
                         // Returning STATUS_SHARING_VIOLATION is undocumented but apparently valid
-                        header.Status = NTStatus.STATUS_SHARING_VIOLATION;
-                        return 0;
+                        return NTStatus.STATUS_SHARING_VIOLATION;
                     }
                     else
                     {
-                        header.Status = NTStatus.STATUS_DATA_ERROR;
-                        return 0;
+                        state.LogToServer(Severity.Debug, "WriteFile: Cannot write '{0}'. Data Error.", openedFilePath);
+                        return NTStatus.STATUS_DATA_ERROR;
                     }
                 }
                 catch (ArgumentOutOfRangeException)
                 {
-                    header.Status = NTStatus.STATUS_DATA_ERROR;
-                    return 0;
+                    state.LogToServer(Severity.Debug, "WriteFile: Cannot write '{0}'. Offset Out Of Range.", openedFilePath);
+                    return NTStatus.STATUS_DATA_ERROR;
                 }
                 catch (UnauthorizedAccessException)
                 {
+                    state.LogToServer(Severity.Debug, "WriteFile: Cannot write '{0}'. Access Denied.", openedFilePath);
                     // The user may have tried to write to a readonly file
-                    header.Status = NTStatus.STATUS_ACCESS_DENIED;
-                    return 0;
+                    return NTStatus.STATUS_ACCESS_DENIED;
                 }
             }
         }