Przeglądaj źródła

Client: Improved NetBios over TCP support

Tal Aloni 5 lat temu
rodzic
commit
0a89a84434
2 zmienionych plików z 166 dodań i 43 usunięć
  1. 83 25
      SMBLibrary/Client/SMB1Client.cs
  2. 83 18
      SMBLibrary/Client/SMB2Client.cs

+ 83 - 25
SMBLibrary/Client/SMB1Client.cs

@@ -1,4 +1,4 @@
-/* Copyright (C) 2014-2019 Tal Aloni <tal.aloni.il@gmail.com>. All rights reserved.
+/* Copyright (C) 2014-2020 Tal Aloni <tal.aloni.il@gmail.com>. All rights reserved.
  * 
  * You can redistribute this program and/or modify it under the terms of
  * the GNU Lesser Public License as published by the Free Software Foundation,
@@ -44,6 +44,9 @@ namespace SMBLibrary.Client
         private List<SMB1Message> m_incomingQueue = new List<SMB1Message>();
         private EventWaitHandle m_incomingQueueEventHandle = new EventWaitHandle(false, EventResetMode.AutoReset);
 
+        private SessionPacket m_sessionResponsePacket;
+        private EventWaitHandle m_sessionResponseEventHandle = new EventWaitHandle(false, EventResetMode.AutoReset);
+
         private ushort m_userID;
         private byte[] m_serverChallenge;
         private byte[] m_securityBlob;
@@ -64,29 +67,55 @@ namespace SMBLibrary.Client
             if (!m_isConnected)
             {
                 m_forceExtendedSecurity = forceExtendedSecurity;
-                m_clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
                 int port;
-                if (transport == SMBTransportType.DirectTCPTransport)
+                if (transport == SMBTransportType.NetBiosOverTCP)
                 {
-                    port = DirectTCPPort;
+                    port = NetBiosOverTCPPort;
                 }
                 else
                 {
-                    port = NetBiosOverTCPPort;
+                    port = DirectTCPPort;
                 }
 
-                try
+                if (!ConnectSocket(serverAddress, port))
                 {
-                    m_clientSocket.Connect(serverAddress, port);
+                    return false;
                 }
-                catch (SocketException)
+                
+                if (transport == SMBTransportType.NetBiosOverTCP)
                 {
-                    return false;
+                    SessionRequestPacket sessionRequest = new SessionRequestPacket();
+                    sessionRequest.CalledName = NetBiosUtils.GetMSNetBiosName("*SMBSERVER", NetBiosSuffix.FileServiceService);
+                    sessionRequest.CallingName = NetBiosUtils.GetMSNetBiosName(Environment.MachineName, NetBiosSuffix.WorkstationService);
+                    TrySendPacket(m_clientSocket, sessionRequest);
+
+                    SessionPacket sessionResponsePacket = WaitForSessionResponsePacket();
+                    if (!(sessionResponsePacket is PositiveSessionResponsePacket))
+                    {
+                        m_clientSocket.Close();
+                        if (!ConnectSocket(serverAddress, port))
+                        {
+                            return false;
+                        }
+
+                        NameServiceClient nameServiceClient = new NameServiceClient(serverAddress);
+                        string serverName = nameServiceClient.GetServerName();
+                        if (serverName == null)
+                        {
+                            return false;
+                        }
+
+                        sessionRequest.CalledName = serverName;
+                        TrySendPacket(m_clientSocket, sessionRequest);
+
+                        sessionResponsePacket = WaitForSessionResponsePacket();
+                        if (!(sessionResponsePacket is PositiveSessionResponsePacket))
+                        {
+                            return false;
+                        }
+                    }
                 }
 
-                ConnectionState state = new ConnectionState();
-                NBTConnectionReceiveBuffer buffer = state.ReceiveBuffer;
-                m_clientSocket.BeginReceive(buffer.Buffer, buffer.WriteOffset, buffer.AvailableLength, SocketFlags.None, new AsyncCallback(OnClientSocketReceive), state);
                 bool supportsDialect = NegotiateDialect(m_forceExtendedSecurity);
                 if (!supportsDialect)
                 {
@@ -100,6 +129,25 @@ namespace SMBLibrary.Client
             return m_isConnected;
         }
 
+        private bool ConnectSocket(IPAddress serverAddress, int port)
+        {
+            m_clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+            
+            try
+            {
+                m_clientSocket.Connect(serverAddress, port);
+            }
+            catch (SocketException)
+            {
+                return false;
+            }
+
+            ConnectionState state = new ConnectionState();
+            NBTConnectionReceiveBuffer buffer = state.ReceiveBuffer;
+            m_clientSocket.BeginReceive(buffer.Buffer, buffer.WriteOffset, buffer.AvailableLength, SocketFlags.None, new AsyncCallback(OnClientSocketReceive), state);
+            return true;
+        }
+
         public void Disconnect()
         {
             if (m_isConnected)
@@ -111,13 +159,6 @@ namespace SMBLibrary.Client
 
         private bool NegotiateDialect(bool forceExtendedSecurity)
         {
-            if (m_transport == SMBTransportType.NetBiosOverTCP)
-            {
-                SessionRequestPacket sessionRequest = new SessionRequestPacket();
-                sessionRequest.CalledName = NetBiosUtils.GetMSNetBiosName("*SMBSERVER", NetBiosSuffix.FileServiceService); ;
-                sessionRequest.CallingName = NetBiosUtils.GetMSNetBiosName(Environment.MachineName, NetBiosSuffix.WorkstationService);
-                TrySendPacket(m_clientSocket, sessionRequest);
-            }
             NegotiateRequest request = new NegotiateRequest();
             request.Dialects.Add(NTLanManagerDialect);
 
@@ -440,13 +481,10 @@ namespace SMBLibrary.Client
             {
                 // [RFC 1001] NetBIOS session keep alives do not require a response from the NetBIOS peer
             }
-            else if (packet is PositiveSessionResponsePacket && m_transport == SMBTransportType.NetBiosOverTCP)
-            {
-            }
-            else if (packet is NegativeSessionResponsePacket && m_transport == SMBTransportType.NetBiosOverTCP)
+            else if ((packet is PositiveSessionResponsePacket || packet is NegativeSessionResponsePacket) && m_transport == SMBTransportType.NetBiosOverTCP)
             {
-                m_clientSocket.Close();
-                m_isConnected = false;
+                m_sessionResponsePacket = packet;
+                m_sessionResponseEventHandle.Set();
             }
             else if (packet is SessionMessagePacket)
             {
@@ -503,6 +541,26 @@ namespace SMBLibrary.Client
             return null;
         }
 
+        internal SessionPacket WaitForSessionResponsePacket()
+        {
+            const int TimeOut = 5000;
+            Stopwatch stopwatch = new Stopwatch();
+            stopwatch.Start();
+            while (stopwatch.ElapsedMilliseconds < TimeOut)
+            {
+                if (m_sessionResponsePacket != null)
+                {
+                    SessionPacket result = m_sessionResponsePacket;
+                    m_sessionResponsePacket = null;
+                    return result;
+                }
+
+                m_sessionResponseEventHandle.WaitOne(100);
+            }
+
+            return null;
+        }
+
         private void Log(string message)
         {
             System.Diagnostics.Debug.Print(message);

+ 83 - 18
SMBLibrary/Client/SMB2Client.cs

@@ -1,4 +1,4 @@
-/* Copyright (C) 2017-2019 Tal Aloni <tal.aloni.il@gmail.com>. All rights reserved.
+/* Copyright (C) 2017-2020 Tal Aloni <tal.aloni.il@gmail.com>. All rights reserved.
  * 
  * You can redistribute this program and/or modify it under the terms of
  * the GNU Lesser Public License as published by the Free Software Foundation,
@@ -37,6 +37,9 @@ namespace SMBLibrary.Client
         private List<SMB2Command> m_incomingQueue = new List<SMB2Command>();
         private EventWaitHandle m_incomingQueueEventHandle = new EventWaitHandle(false, EventResetMode.AutoReset);
 
+        private SessionPacket m_sessionResponsePacket;
+        private EventWaitHandle m_sessionResponseEventHandle = new EventWaitHandle(false, EventResetMode.AutoReset);
+
         private uint m_messageID = 0;
         private SMB2Dialect m_dialect;
         private bool m_signingRequired;
@@ -56,29 +59,55 @@ namespace SMBLibrary.Client
             m_transport = transport;
             if (!m_isConnected)
             {
-                m_clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
                 int port;
-                if (transport == SMBTransportType.DirectTCPTransport)
+                if (transport == SMBTransportType.NetBiosOverTCP)
                 {
-                    port = DirectTCPPort;
+                    port = NetBiosOverTCPPort;
                 }
                 else
                 {
-                    port = NetBiosOverTCPPort;
+                    port = DirectTCPPort;
                 }
 
-                try
+                if (!ConnectSocket(serverAddress, port))
                 {
-                    m_clientSocket.Connect(serverAddress, port);
+                    return false;
                 }
-                catch (SocketException)
+
+                if (transport == SMBTransportType.NetBiosOverTCP)
                 {
-                    return false;
+                    SessionRequestPacket sessionRequest = new SessionRequestPacket();
+                    sessionRequest.CalledName = NetBiosUtils.GetMSNetBiosName("*SMBSERVER", NetBiosSuffix.FileServiceService);
+                    sessionRequest.CallingName = NetBiosUtils.GetMSNetBiosName(Environment.MachineName, NetBiosSuffix.WorkstationService);
+                    TrySendPacket(m_clientSocket, sessionRequest);
+
+                    SessionPacket sessionResponsePacket = WaitForSessionResponsePacket();
+                    if (!(sessionResponsePacket is PositiveSessionResponsePacket))
+                    {
+                        m_clientSocket.Close();
+                        if (!ConnectSocket(serverAddress, port))
+                        {
+                            return false;
+                        }
+
+                        NameServiceClient nameServiceClient = new NameServiceClient(serverAddress);
+                        string serverName = nameServiceClient.GetServerName();
+                        if (serverName == null)
+                        {
+                            return false;
+                        }
+
+                        sessionRequest.CalledName = serverName;
+                        TrySendPacket(m_clientSocket, sessionRequest);
+
+                        sessionResponsePacket = WaitForSessionResponsePacket();
+                        if (!(sessionResponsePacket is PositiveSessionResponsePacket))
+                        {
+                            return false;
+                        }
+                    }
                 }
 
-                ConnectionState state = new ConnectionState();
-                NBTConnectionReceiveBuffer buffer = state.ReceiveBuffer;
-                m_clientSocket.BeginReceive(buffer.Buffer, buffer.WriteOffset, buffer.AvailableLength, SocketFlags.None, new AsyncCallback(OnClientSocketReceive), state);
                 bool supportsDialect = NegotiateDialect();
                 if (!supportsDialect)
                 {
@@ -92,6 +121,25 @@ namespace SMBLibrary.Client
             return m_isConnected;
         }
 
+        private bool ConnectSocket(IPAddress serverAddress, int port)
+        {
+            m_clientSocket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);
+
+            try
+            {
+                m_clientSocket.Connect(serverAddress, port);
+            }
+            catch (SocketException)
+            {
+                return false;
+            }
+
+            ConnectionState state = new ConnectionState();
+            NBTConnectionReceiveBuffer buffer = state.ReceiveBuffer;
+            m_clientSocket.BeginReceive(buffer.Buffer, buffer.WriteOffset, buffer.AvailableLength, SocketFlags.None, new AsyncCallback(OnClientSocketReceive), state);
+            return true;
+        }
+
         public void Disconnect()
         {
             if (m_isConnected)
@@ -328,13 +376,10 @@ namespace SMBLibrary.Client
             {
                 // [RFC 1001] NetBIOS session keep alives do not require a response from the NetBIOS peer
             }
-            else if (packet is PositiveSessionResponsePacket && m_transport == SMBTransportType.NetBiosOverTCP)
-            {
-            }
-            else if (packet is NegativeSessionResponsePacket && m_transport == SMBTransportType.NetBiosOverTCP)
+            else if ((packet is PositiveSessionResponsePacket || packet is NegativeSessionResponsePacket) && m_transport == SMBTransportType.NetBiosOverTCP)
             {
-                m_clientSocket.Close();
-                m_isConnected = false;
+                m_sessionResponsePacket = packet;
+                m_sessionResponseEventHandle.Set();
             }
             else if (packet is SessionMessagePacket)
             {
@@ -391,6 +436,26 @@ namespace SMBLibrary.Client
             return null;
         }
 
+        internal SessionPacket WaitForSessionResponsePacket()
+        {
+            const int TimeOut = 5000;
+            Stopwatch stopwatch = new Stopwatch();
+            stopwatch.Start();
+            while (stopwatch.ElapsedMilliseconds < TimeOut)
+            {
+                if (m_sessionResponsePacket != null)
+                {
+                    SessionPacket result = m_sessionResponsePacket;
+                    m_sessionResponsePacket = null;
+                    return result;
+                }
+
+                m_sessionResponseEventHandle.WaitOne(100);
+            }
+
+            return null;
+        }
+
         private void Log(string message)
         {
             System.Diagnostics.Debug.Print(message);