소스 검색

Windows Authentication: Improved implementation

Tal Aloni 8 년 전
부모
커밋
d7e33e465d
2개의 변경된 파일78개의 추가작업 그리고 65개의 파일을 삭제
  1. 44 18
      SMBLibrary/Win32/Authentication/SSPIHelper.cs
  2. 34 47
      SMBLibrary/Win32/Authentication/SecBufferDesc.cs

+ 44 - 18
SMBLibrary/Win32/Authentication/SSPIHelper.cs

@@ -147,7 +147,7 @@ namespace SMBLibrary.Authentication.Win32
         );
 
         [DllImport("Secur32.dll")]
-        private extern static int DeleteSecurityContext(
+        public extern static int DeleteSecurityContext(
             ref SecHandle phContext
         );
 
@@ -205,13 +205,14 @@ namespace SMBLibrary.Authentication.Win32
 
         public static byte[] GetType1Message(string domainName, string userName, string password, out SecHandle clientContext)
         {
-            SecHandle handle = AcquireNTLMCredentialsHandle(domainName, userName, password);
+            SecHandle credentialsHandle = AcquireNTLMCredentialsHandle(domainName, userName, password);
             clientContext = new SecHandle();
-            SecBufferDesc output = new SecBufferDesc(MAX_TOKEN_SIZE);
+            SecBuffer outputBuffer = new SecBuffer(MAX_TOKEN_SIZE);
+            SecBufferDesc output = new SecBufferDesc(outputBuffer);
             uint contextAttributes;
             SECURITY_INTEGER expiry;
 
-            int result = InitializeSecurityContext(ref handle, IntPtr.Zero, null, ISC_REQ_CONFIDENTIALITY | ISC_REQ_INTEGRITY, 0, SECURITY_NATIVE_DREP, IntPtr.Zero, 0, ref clientContext, ref output, out contextAttributes, out expiry);
+            int result = InitializeSecurityContext(ref credentialsHandle, IntPtr.Zero, null, ISC_REQ_CONFIDENTIALITY | ISC_REQ_INTEGRITY, 0, SECURITY_NATIVE_DREP, IntPtr.Zero, 0, ref clientContext, ref output, out contextAttributes, out expiry);
             if (result != SEC_E_OK && result != SEC_I_CONTINUE_NEEDED)
             {
                 if ((uint)result == SEC_E_INVALID_HANDLE)
@@ -227,14 +228,20 @@ namespace SMBLibrary.Authentication.Win32
                     throw new Exception("InitializeSecurityContext failed, Error code " + ((uint)result).ToString("X"));
                 }
             }
-            return output.GetSecBufferBytes();
+            FreeCredentialsHandle(ref credentialsHandle);
+            byte[] messageBytes = outputBuffer.GetBufferBytes();
+            outputBuffer.Dispose();
+            output.Dispose();
+            return messageBytes;
         }
 
         public static byte[] GetType3Message(SecHandle clientContext, byte[] type2Message)
         {
             SecHandle newContext = new SecHandle();
-            SecBufferDesc input = new SecBufferDesc(type2Message);
-            SecBufferDesc output = new SecBufferDesc(MAX_TOKEN_SIZE);
+            SecBuffer inputBuffer = new SecBuffer(type2Message);
+            SecBufferDesc input = new SecBufferDesc(inputBuffer);
+            SecBuffer outputBuffer = new SecBuffer(MAX_TOKEN_SIZE);
+            SecBufferDesc output = new SecBufferDesc(outputBuffer);
             uint contextAttributes;
             SECURITY_INTEGER expiry;
 
@@ -254,19 +261,26 @@ namespace SMBLibrary.Authentication.Win32
                     throw new Exception("InitializeSecurityContext failed, error code " + ((uint)result).ToString("X"));
                 }
             }
-            return output.GetSecBufferBytes();
+            byte[] messageBytes = outputBuffer.GetBufferBytes();
+            inputBuffer.Dispose();
+            input.Dispose();
+            outputBuffer.Dispose();
+            output.Dispose();
+            return messageBytes;
         }
 
         public static byte[] GetType2Message(byte[] type1MessageBytes, out SecHandle serverContext)
         {
-            SecHandle handle = AcquireNTLMCredentialsHandle();
-            SecBufferDesc type1Message = new SecBufferDesc(type1MessageBytes);
+            SecHandle credentialsHandle = AcquireNTLMCredentialsHandle();
+            SecBuffer inputBuffer = new SecBuffer(type1MessageBytes);
+            SecBufferDesc input = new SecBufferDesc(inputBuffer);
             serverContext = new SecHandle();
-            SecBufferDesc output = new SecBufferDesc(MAX_TOKEN_SIZE);
+            SecBuffer outputBuffer = new SecBuffer(MAX_TOKEN_SIZE);
+            SecBufferDesc output = new SecBufferDesc(outputBuffer);
             uint contextAttributes;
             SECURITY_INTEGER timestamp;
 
-            int result = AcceptSecurityContext(ref handle, IntPtr.Zero, ref type1Message, ASC_REQ_INTEGRITY | ASC_REQ_CONFIDENTIALITY, SECURITY_NATIVE_DREP, ref serverContext, ref output, out contextAttributes, out timestamp);
+            int result = AcceptSecurityContext(ref credentialsHandle, IntPtr.Zero, ref input, ASC_REQ_INTEGRITY | ASC_REQ_CONFIDENTIALITY, SECURITY_NATIVE_DREP, ref serverContext, ref output, out contextAttributes, out timestamp);
             if (result != SEC_E_OK && result != SEC_I_CONTINUE_NEEDED)
             {
                 if ((uint)result == SEC_E_INVALID_HANDLE)
@@ -282,8 +296,13 @@ namespace SMBLibrary.Authentication.Win32
                     throw new Exception("AcceptSecurityContext failed, error code " + ((uint)result).ToString("X"));
                 }
             }
-            FreeCredentialsHandle(ref handle);
-            return output.GetSecBufferBytes();
+            FreeCredentialsHandle(ref credentialsHandle);
+            byte[] messageBytes = outputBuffer.GetBufferBytes();
+            inputBuffer.Dispose();
+            input.Dispose();
+            outputBuffer.Dispose();
+            output.Dispose();
+            return messageBytes;
         }
 
         /// <summary>
@@ -303,13 +322,20 @@ namespace SMBLibrary.Authentication.Win32
         public static bool AuthenticateType3Message(SecHandle serverContext, byte[] type3MessageBytes)
         {
             SecHandle newContext = new SecHandle();
-            SecBufferDesc type3Message = new SecBufferDesc(type3MessageBytes);
-            SecBufferDesc output = new SecBufferDesc(MAX_TOKEN_SIZE);
+            SecBuffer inputBuffer = new SecBuffer(type3MessageBytes);
+            SecBufferDesc input = new SecBufferDesc(inputBuffer);
+            SecBuffer outputBuffer = new SecBuffer(MAX_TOKEN_SIZE);
+            SecBufferDesc output = new SecBufferDesc(outputBuffer);
             uint contextAttributes;
             SECURITY_INTEGER timestamp;
 
-            int result = AcceptSecurityContext(IntPtr.Zero, ref serverContext, ref type3Message, ASC_REQ_INTEGRITY | ASC_REQ_CONFIDENTIALITY, SECURITY_NATIVE_DREP, ref newContext, ref output, out contextAttributes, out timestamp);
-            
+            int result = AcceptSecurityContext(IntPtr.Zero, ref serverContext, ref input, ASC_REQ_INTEGRITY | ASC_REQ_CONFIDENTIALITY, SECURITY_NATIVE_DREP, ref newContext, ref output, out contextAttributes, out timestamp);
+
+            inputBuffer.Dispose();
+            input.Dispose();
+            outputBuffer.Dispose();
+            output.Dispose();
+
             if (result == SEC_E_OK)
             {
                 return true;

+ 34 - 47
SMBLibrary/Win32/Authentication/SecBufferDesc.cs

@@ -1,4 +1,4 @@
-/* Copyright (C) 2014 Tal Aloni <tal.aloni.il@gmail.com>. All rights reserved.
+/* Copyright (C) 2014-2017 Tal Aloni <tal.aloni.il@gmail.com>. All rights reserved.
  * 
  * You can redistribute this program and/or modify it under the terms of
  * the GNU Lesser Public License as published by the Free Software Foundation,
@@ -11,7 +11,7 @@ using System.Text;
 
 namespace SMBLibrary.Authentication.Win32
 {
-    public enum SecBufferType
+    public enum SecBufferType : uint
     {
         SECBUFFER_VERSION = 0,
         SECBUFFER_EMPTY = 0,
@@ -20,33 +20,33 @@ namespace SMBLibrary.Authentication.Win32
     }
 
     [StructLayout(LayoutKind.Sequential)]
-    public struct SecBuffer
+    public struct SecBuffer : IDisposable
     {
-        public int cbBuffer;
-        public int BufferType;
-        public IntPtr pvBuffer;
+        public uint cbBuffer;    // Specifies the size, in bytes, of the buffer pointed to by the pvBuffer member.
+        public uint BufferType;
+        public IntPtr pvBuffer; // A pointer to a buffer.
 
         public SecBuffer(int bufferSize)
         {
-            cbBuffer = bufferSize;
-            BufferType = (int)SecBufferType.SECBUFFER_TOKEN;
+            cbBuffer = (uint)bufferSize;
+            BufferType = (uint)SecBufferType.SECBUFFER_TOKEN;
             pvBuffer = Marshal.AllocHGlobal(bufferSize);
         }
 
         public SecBuffer(byte[] secBufferBytes)
         {
-            cbBuffer = secBufferBytes.Length;
-            BufferType = (int)SecBufferType.SECBUFFER_TOKEN;
-            pvBuffer = Marshal.AllocHGlobal(cbBuffer);
-            Marshal.Copy(secBufferBytes, 0, pvBuffer, cbBuffer);
+            cbBuffer = (uint)secBufferBytes.Length;
+            BufferType = (uint)SecBufferType.SECBUFFER_TOKEN;
+            pvBuffer = Marshal.AllocHGlobal(secBufferBytes.Length);
+            Marshal.Copy(secBufferBytes, 0, pvBuffer, secBufferBytes.Length);
         }
 
         public SecBuffer(byte[] secBufferBytes, SecBufferType bufferType)
         {
-            cbBuffer = secBufferBytes.Length;
-            BufferType = (int)bufferType;
-            pvBuffer = Marshal.AllocHGlobal(cbBuffer);
-            Marshal.Copy(secBufferBytes, 0, pvBuffer, cbBuffer);
+            cbBuffer = (uint)secBufferBytes.Length;
+            BufferType = (uint)bufferType;
+            pvBuffer = Marshal.AllocHGlobal(secBufferBytes.Length);
+            Marshal.Copy(secBufferBytes, 0, pvBuffer, secBufferBytes.Length);
         }
 
         public void Dispose()
@@ -58,63 +58,50 @@ namespace SMBLibrary.Authentication.Win32
             }
         }
 
-        public byte[] GetBytes()
+        public byte[] GetBufferBytes()
         {
             byte[] buffer = null;
             if (cbBuffer > 0)
             {
                 buffer = new byte[cbBuffer];
-                Marshal.Copy(pvBuffer, buffer, 0, cbBuffer);
+                Marshal.Copy(pvBuffer, buffer, 0, (int)cbBuffer);
             }
             return buffer;
         }
     }
 
-    /// <summary>
-    /// Simplified SecBufferDesc struct with only one SecBuffer
-    /// </summary>
     [StructLayout(LayoutKind.Sequential)]
-    public struct SecBufferDesc
+    public struct SecBufferDesc : IDisposable
     {
-        public int ulVersion;
-        public int cBuffers;
-        public IntPtr pBuffers;
+        public uint ulVersion;
+        public uint cBuffers;    // Indicates the number of SecBuffer structures in the pBuffers array.
+        public IntPtr pBuffers; // Pointer to an array of SecBuffer structures.
 
-        public SecBufferDesc(int bufferSize)
+        public SecBufferDesc(SecBuffer buffer) : this(new SecBuffer[] { buffer })
         {
-            ulVersion = (int)SecBufferType.SECBUFFER_VERSION;
-            cBuffers = 1;
-            SecBuffer secBuffer = new SecBuffer(bufferSize);
-            pBuffers = Marshal.AllocHGlobal(Marshal.SizeOf(secBuffer));
-            Marshal.StructureToPtr(secBuffer, pBuffers, false);
         }
 
-        public SecBufferDesc(byte[] secBufferBytes)
+        public SecBufferDesc(SecBuffer[] buffers)
         {
-            ulVersion = (int)SecBufferType.SECBUFFER_VERSION;
-            cBuffers = 1;
-            SecBuffer secBuffer = new SecBuffer(secBufferBytes);
-            pBuffers = Marshal.AllocHGlobal(Marshal.SizeOf(secBuffer));
-            Marshal.StructureToPtr(secBuffer, pBuffers, false);
+            int secBufferSize = Marshal.SizeOf(typeof(SecBuffer));
+            ulVersion = (uint)SecBufferType.SECBUFFER_VERSION;
+            cBuffers = (uint)buffers.Length;
+            pBuffers = Marshal.AllocHGlobal(buffers.Length * secBufferSize);
+            IntPtr currentBuffer = pBuffers;
+            for (int index = 0; index < buffers.Length; index++)
+            {
+                Marshal.StructureToPtr(buffers[index], currentBuffer, false);
+                currentBuffer = new IntPtr(currentBuffer.ToInt64() + secBufferSize);
+            }
         }
 
         public void Dispose()
         {
             if (pBuffers != IntPtr.Zero)
             {
-                SecBuffer secBuffer = (SecBuffer)Marshal.PtrToStructure(pBuffers, typeof(SecBuffer));
-                secBuffer.Dispose();
                 Marshal.FreeHGlobal(pBuffers);
                 pBuffers = IntPtr.Zero;
             }
         }
-
-        public byte[] GetSecBufferBytes()
-        {
-            if (pBuffers == IntPtr.Zero)
-                throw new ObjectDisposedException("SecBufferDesc");
-            SecBuffer secBuffer = (SecBuffer)Marshal.PtrToStructure(pBuffers, typeof(SecBuffer));
-            return secBuffer.GetBytes();
-        }
     }
 }