Browse Source

Improved session management logic

Tal Aloni 8 years ago
parent
commit
049cb5b104

+ 1 - 0
SMBLibrary/Enums/NTStatus.cs

@@ -29,6 +29,7 @@ namespace SMBLibrary
         STATUS_TOO_MANY_SESSIONS = 0xC00000CE,
         STATUS_TOO_MANY_OPENED_FILES = 0xC000011F,
         STATUS_CANNOT_DELETE = 0xC0000121,
+        STATUS_USER_SESSION_DELETED = 0xC0000203,
         STATUS_INSUFF_SERVER_RESOURCES = 0xC0000205,
         
         STATUS_INVALID_SMB = 0x00010002,     // CIFS/SMB1: A corrupt or invalid SMB request was received

+ 1 - 0
SMBLibrary/SMBLibrary.csproj

@@ -108,6 +108,7 @@
     <Compile Include="Server\ConnectionState\OpenFileObject.cs" />
     <Compile Include="Server\ConnectionState\ProcessStateObject.cs" />
     <Compile Include="Server\ConnectionState\SMB1ConnectionState.cs" />
+    <Compile Include="Server\ConnectionState\SMB1Session.cs" />
     <Compile Include="Server\Exceptions\EmptyPasswordNotAllowedException.cs" />
     <Compile Include="Server\Exceptions\InvalidRequestException.cs" />
     <Compile Include="Server\Exceptions\UnsupportedInformationLevelException.cs" />

+ 57 - 147
SMBLibrary/Server/ConnectionState/SMB1ConnectionState.cs

@@ -18,24 +18,14 @@ namespace SMBLibrary.Server
         public bool LargeWrite;
 
         // Key is UID
-        private Dictionary<ushort, string> m_connectedUsers = new Dictionary<ushort, string>();
-        private ushort m_nextUID = 1;
+        private Dictionary<ushort, SMB1Session> m_sessions = new Dictionary<ushort, SMB1Session>();
+        private ushort m_nextUID = 1; // UID MUST be unique within an SMB connection
+        private ushort m_nextTID = 1; // TID MUST be unique within an SMB connection
+        private ushort m_nextFID = 1; // FID MUST be unique within an SMB connection
 
-        // Key is TID
-        private Dictionary<ushort, ISMBShare> m_connectedTrees = new Dictionary<ushort, ISMBShare>();
-        private ushort m_nextTID = 1;
-
-        // Key is FID
-        private Dictionary<ushort, OpenFileObject> m_openFiles = new Dictionary<ushort, OpenFileObject>();
-        private ushort m_nextFID = 1;
-
-        // Key is PID
+        // Key is PID (PID MUST be unique within an SMB connection)
         private Dictionary<uint, ProcessStateObject> m_processStateList = new Dictionary<uint, ProcessStateObject>();
 
-        private const int MaxSearches = 2048; // Windows servers initialize Server.MaxSearches to 2048.
-        public Dictionary<ushort, List<FileSystemEntry>> OpenSearches = new Dictionary<ushort, List<FileSystemEntry>>();
-        private ushort m_nextSearchHandle = 1;
-
         public SMB1ConnectionState(ConnectionState state) : base(state)
         {
         }
@@ -44,7 +34,7 @@ namespace SMBLibrary.Server
         /// An open UID MUST be unique within an SMB connection.
         /// The value of 0xFFFE SHOULD NOT be used as a valid UID. All other possible values for a UID, excluding zero (0x0000), are valid.
         /// </summary>
-        private ushort? AllocateUserID()
+        public ushort? AllocateUserID()
         {
             for (ushort offset = 0; offset < UInt16.MaxValue; offset++)
             {
@@ -53,7 +43,7 @@ namespace SMBLibrary.Server
                 {
                     continue;
                 }
-                if (!m_connectedUsers.ContainsKey(userID))
+                if (!m_sessions.ContainsKey(userID))
                 {
                     m_nextUID = (ushort)(userID + 1);
                     return userID;
@@ -62,43 +52,41 @@ namespace SMBLibrary.Server
             return null;
         }
 
-        public ushort? AddConnectedUser(string userName)
+        public SMB1Session CreateSession(ushort userID, string userName)
         {
-            ushort? userID = AllocateUserID();
-            if (userID.HasValue)
-            {
-                m_connectedUsers.Add(userID.Value, userName);
-            }
-            return userID;
+            SMB1Session session = new SMB1Session(this, userID, userName);
+            m_sessions.Add(userID, session);
+            return session;
         }
 
-        public string GetConnectedUserName(ushort userID)
+        /// <returns>null if all UserID values have already been allocated</returns>
+        public SMB1Session CreateSession(string userName)
         {
-            if (m_connectedUsers.ContainsKey(userID))
-            {
-                return m_connectedUsers[userID];
-            }
-            else
+            ushort? userID = AllocateUserID();
+            if (userID.HasValue)
             {
-                return null;
+                return CreateSession(userID.Value, userName);
             }
+            return null;
         }
 
-        public bool IsAuthenticated(ushort userID)
+        public SMB1Session GetSession(ushort userID)
         {
-            return m_connectedUsers.ContainsKey(userID);
+            SMB1Session session;
+            m_sessions.TryGetValue(userID, out session);
+            return session;
         }
 
-        public void RemoveConnectedUser(ushort userID)
+        public void RemoveSession(ushort userID)
         {
-            m_connectedUsers.Remove(userID);
+            m_sessions.Remove(userID);
         }
 
         /// <summary>
         /// An open TID MUST be unique within an SMB connection.
         /// The value 0xFFFF MUST NOT be used as a valid TID. All other possible values for TID, including zero (0x0000), are valid.
         /// </summary>
-        private ushort? AllocateTreeID()
+        public ushort? AllocateTreeID()
         {
             for (ushort offset = 0; offset < UInt16.MaxValue; offset++)
             {
@@ -107,7 +95,7 @@ namespace SMBLibrary.Server
                 {
                     continue;
                 }
-                if (!m_connectedTrees.ContainsKey(treeID))
+                if (!IsTreeIDAllocated(treeID))
                 {
                     m_nextTID = (ushort)(treeID + 1);
                     return treeID;
@@ -116,72 +104,24 @@ namespace SMBLibrary.Server
             return null;
         }
 
-        public ushort? AddConnectedTree(ISMBShare share)
-        {
-            ushort? treeID = AllocateTreeID();
-            if (treeID.HasValue)
-            {
-                m_connectedTrees.Add(treeID.Value, share);
-            }
-            return treeID;
-        }
-
-        public ISMBShare GetConnectedTree(ushort treeID)
-        {
-            if (m_connectedTrees.ContainsKey(treeID))
-            {
-                return m_connectedTrees[treeID];
-            }
-            else
-            {
-                return null;
-            }
-        }
-
-        public void RemoveConnectedTree(ushort treeID)
-        {
-            m_connectedTrees.Remove(treeID);
-        }
-
-        public bool IsTreeConnected(ushort treeID)
-        {
-            return m_connectedTrees.ContainsKey(treeID);
-        }
-
-        public ProcessStateObject GetProcessState(uint processID)
-        {
-            if (m_processStateList.ContainsKey(processID))
-            {
-                return m_processStateList[processID];
-            }
-            else
-            {
-                return null;
-            }
-        }
-
-        /// <summary>
-        /// Get or Create process state
-        /// </summary>
-        public ProcessStateObject ObtainProcessState(uint processID)
+        private bool IsTreeIDAllocated(ushort treeID)
         {
-            if (m_processStateList.ContainsKey(processID))
+            foreach (SMB1Session session in m_sessions.Values)
             {
-                return m_processStateList[processID];
-            }
-            else
-            {
-                ProcessStateObject processState = new ProcessStateObject();
-                m_processStateList[processID] = processState;
-                return processState;
+                if (session.GetConnectedTree(treeID) != null)
+                {
+                    return true;
+                }
             }
+            return false;
         }
 
         /// <summary>
+        /// A FID returned from an Open or Create operation MUST be unique within an SMB connection.
         /// The value 0xFFFF MUST NOT be used as a valid FID. All other possible values for FID, including zero (0x0000) are valid.
         /// </summary>
         /// <returns></returns>
-        private ushort? AllocateFileID()
+        public ushort? AllocateFileID()
         {
             for (ushort offset = 0; offset < UInt16.MaxValue; offset++)
             {
@@ -190,7 +130,7 @@ namespace SMBLibrary.Server
                 {
                     continue;
                 }
-                if (!m_openFiles.ContainsKey(fileID))
+                if (!IsFileIDAllocated(fileID))
                 {
                     m_nextFID = (ushort)(fileID + 1);
                     return fileID;
@@ -199,33 +139,23 @@ namespace SMBLibrary.Server
             return null;
         }
 
-        /// <param name="relativePath">Should include the path relative to the share</param>
-        /// <returns>FileID</returns>
-        public ushort? AddOpenFile(string relativePath)
-        {
-            return AddOpenFile(relativePath, null);
-        }
-
-        public ushort? AddOpenFile(string relativePath, Stream stream)
+        private bool IsFileIDAllocated(ushort fileID)
         {
-            return AddOpenFile(relativePath, stream, false);
-        }
-
-        public ushort? AddOpenFile(string relativePath, Stream stream, bool deleteOnClose)
-        {
-            ushort? fileID = AllocateFileID();
-            if (fileID.HasValue)
+            foreach (SMB1Session session in m_sessions.Values)
             {
-                m_openFiles.Add(fileID.Value, new OpenFileObject(relativePath, stream, deleteOnClose));
+                if (session.GetOpenFileObject(fileID) != null)
+                {
+                    return true;
+                }
             }
-            return fileID;
+            return false;
         }
 
-        public OpenFileObject GetOpenFileObject(ushort fileID)
+        public ProcessStateObject GetProcessState(uint processID)
         {
-            if (m_openFiles.ContainsKey(fileID))
+            if (m_processStateList.ContainsKey(processID))
             {
-                return m_openFiles[fileID];
+                return m_processStateList[processID];
             }
             else
             {
@@ -233,15 +163,21 @@ namespace SMBLibrary.Server
             }
         }
 
-        public void RemoveOpenFile(ushort fileID)
+        /// <summary>
+        /// Get or Create process state
+        /// </summary>
+        public ProcessStateObject ObtainProcessState(uint processID)
         {
-            Stream stream = m_openFiles[fileID].Stream;
-            if (stream != null)
+            if (m_processStateList.ContainsKey(processID))
             {
-                LogToServer(Severity.Verbose, "Closing file '{0}'", m_openFiles[fileID].Path);
-                stream.Close();
+                return m_processStateList[processID];
+            }
+            else
+            {
+                ProcessStateObject processState = new ProcessStateObject();
+                m_processStateList[processID] = processState;
+                return processState;
             }
-            m_openFiles.Remove(fileID);
         }
 
         public uint? GetMaxDataCount(uint processID)
@@ -256,31 +192,5 @@ namespace SMBLibrary.Server
                 return null;
             }
         }
-
-        public ushort? AllocateSearchHandle()
-        {
-            for (ushort offset = 0; offset < UInt16.MaxValue; offset++)
-            {
-                ushort searchHandle = (ushort)(m_nextSearchHandle + offset);
-                if (searchHandle == 0 || searchHandle == 0xFFFF)
-                {
-                    continue;
-                }
-                if (!OpenSearches.ContainsKey(searchHandle))
-                {
-                    m_nextSearchHandle = (ushort)(searchHandle + 1);
-                    return searchHandle;
-                }
-            }
-            return null;
-        }
-
-        public void ReleaseSearchHandle(ushort searchHandle)
-        {
-            if (OpenSearches.ContainsKey(searchHandle))
-            {
-                OpenSearches.Remove(searchHandle);
-            }
-        }
     }
 }

+ 147 - 0
SMBLibrary/Server/ConnectionState/SMB1Session.cs

@@ -0,0 +1,147 @@
+/* Copyright (C) 2014-2017 Tal Aloni <tal.aloni.il@gmail.com>. All rights reserved.
+ * 
+ * You can redistribute this program and/or modify it under the terms of
+ * the GNU Lesser Public License as published by the Free Software Foundation,
+ * either version 3 of the License, or (at your option) any later version.
+ */
+using System;
+using System.Collections.Generic;
+using System.IO;
+using Utilities;
+
+namespace SMBLibrary.Server
+{
+    public class SMB1Session
+    {
+        private const int MaxSearches = 2048; // Windows servers initialize Server.MaxSearches to 2048.
+
+        private SMB1ConnectionState m_connection;
+        private ushort m_userID;
+        private string m_userName;
+
+        // Key is TID
+        private Dictionary<ushort, ISMBShare> m_connectedTrees = new Dictionary<ushort, ISMBShare>();
+
+        // Key is FID
+        private Dictionary<ushort, OpenFileObject> m_openFiles = new Dictionary<ushort, OpenFileObject>();
+
+        // Key is search handle a.k.a. Search ID
+        public Dictionary<ushort, List<FileSystemEntry>> OpenSearches = new Dictionary<ushort, List<FileSystemEntry>>();
+        private ushort m_nextSearchHandle = 1;
+
+        public SMB1Session(SMB1ConnectionState connection, ushort userID, string userName)
+        {
+            m_connection = connection;
+            m_userID = userID;
+            m_userName = userName;
+        }
+
+        public ushort? AddConnectedTree(ISMBShare share)
+        {
+            ushort? treeID = m_connection.AllocateTreeID();
+            if (treeID.HasValue)
+            {
+                m_connectedTrees.Add(treeID.Value, share);
+            }
+            return treeID;
+        }
+
+        public ISMBShare GetConnectedTree(ushort treeID)
+        {
+            ISMBShare share;
+            m_connectedTrees.TryGetValue(treeID, out share);
+            return share;
+        }
+
+        public void RemoveConnectedTree(ushort treeID)
+        {
+            m_connectedTrees.Remove(treeID);
+        }
+
+        public bool IsTreeConnected(ushort treeID)
+        {
+            return m_connectedTrees.ContainsKey(treeID);
+        }
+
+        /// <param name="relativePath">Should include the path relative to the share</param>
+        /// <returns>FileID</returns>
+        public ushort? AddOpenFile(string relativePath)
+        {
+            return AddOpenFile(relativePath, null);
+        }
+
+        public ushort? AddOpenFile(string relativePath, Stream stream)
+        {
+            return AddOpenFile(relativePath, stream, false);
+        }
+
+        public ushort? AddOpenFile(string relativePath, Stream stream, bool deleteOnClose)
+        {
+            ushort? fileID = m_connection.AllocateFileID();
+            if (fileID.HasValue)
+            {
+                m_openFiles.Add(fileID.Value, new OpenFileObject(relativePath, stream, deleteOnClose));
+            }
+            return fileID;
+        }
+
+        public OpenFileObject GetOpenFileObject(ushort fileID)
+        {
+            OpenFileObject openFile;
+            m_openFiles.TryGetValue(fileID, out openFile);
+            return openFile;
+        }
+
+        public void RemoveOpenFile(ushort fileID)
+        {
+            Stream stream = m_openFiles[fileID].Stream;
+            if (stream != null)
+            {
+                stream.Close();
+            }
+            m_openFiles.Remove(fileID);
+        }
+
+        public ushort? AllocateSearchHandle()
+        {
+            for (ushort offset = 0; offset < UInt16.MaxValue; offset++)
+            {
+                ushort searchHandle = (ushort)(m_nextSearchHandle + offset);
+                if (searchHandle == 0 || searchHandle == 0xFFFF)
+                {
+                    continue;
+                }
+                if (!OpenSearches.ContainsKey(searchHandle))
+                {
+                    m_nextSearchHandle = (ushort)(searchHandle + 1);
+                    return searchHandle;
+                }
+            }
+            return null;
+        }
+
+        public void ReleaseSearchHandle(ushort searchHandle)
+        {
+            if (OpenSearches.ContainsKey(searchHandle))
+            {
+                OpenSearches.Remove(searchHandle);
+            }
+        }
+
+        public ushort UserID
+        {
+            get
+            {
+                return m_userID;
+            }
+        }
+
+        public string UserName
+        {
+            get
+            {
+                return m_userName;
+            }
+        }
+    }
+}

+ 13 - 13
SMBLibrary/Server/SMB1/FileSystemResponseHelper.cs

@@ -17,8 +17,8 @@ namespace SMBLibrary.Server.SMB1
     {
         internal static SMB1Command GetCreateDirectoryResponse(SMB1Header header, CreateDirectoryRequest request, FileSystemShare share, SMB1ConnectionState state)
         {
-            string userName = state.GetConnectedUserName(header.UID);
-            if (!share.HasWriteAccess(userName))
+            SMB1Session session = state.GetSession(header.UID);
+            if (!share.HasWriteAccess(session.UserName))
             {
                 header.Status = NTStatus.STATUS_ACCESS_DENIED;
                 return new ErrorResponse(CommandName.SMB_COM_CREATE_DIRECTORY);
@@ -47,8 +47,8 @@ namespace SMBLibrary.Server.SMB1
 
         internal static SMB1Command GetDeleteDirectoryResponse(SMB1Header header, DeleteDirectoryRequest request, FileSystemShare share, SMB1ConnectionState state)
         {
-            string userName = state.GetConnectedUserName(header.UID);
-            if (!share.HasWriteAccess(userName))
+            SMB1Session session = state.GetSession(header.UID);
+            if (!share.HasWriteAccess(session.UserName))
             {
                 header.Status = NTStatus.STATUS_ACCESS_DENIED;
                 return new ErrorResponse(CommandName.SMB_COM_DELETE_DIRECTORY);
@@ -102,8 +102,8 @@ namespace SMBLibrary.Server.SMB1
 
         internal static SMB1Command GetDeleteResponse(SMB1Header header, DeleteRequest request, FileSystemShare share, SMB1ConnectionState state)
         {
-            string userName = state.GetConnectedUserName(header.UID);
-            if (!share.HasWriteAccess(userName))
+            SMB1Session session = state.GetSession(header.UID);
+            if (!share.HasWriteAccess(session.UserName))
             {
                 header.Status = NTStatus.STATUS_ACCESS_DENIED;
                 return new ErrorResponse(CommandName.SMB_COM_DELETE);
@@ -145,8 +145,8 @@ namespace SMBLibrary.Server.SMB1
 
         internal static SMB1Command GetRenameResponse(SMB1Header header, RenameRequest request, FileSystemShare share, SMB1ConnectionState state)
         {
-            string userName = state.GetConnectedUserName(header.UID);
-            if (!share.HasWriteAccess(userName))
+            SMB1Session session = state.GetSession(header.UID);
+            if (!share.HasWriteAccess(session.UserName))
             {
                 header.Status = NTStatus.STATUS_ACCESS_DENIED;
                 return new ErrorResponse(CommandName.SMB_COM_RENAME);
@@ -209,8 +209,8 @@ namespace SMBLibrary.Server.SMB1
 
         internal static SMB1Command GetSetInformationResponse(SMB1Header header, SetInformationRequest request, FileSystemShare share, SMB1ConnectionState state)
         {
-            string userName = state.GetConnectedUserName(header.UID);
-            if (!share.HasWriteAccess(userName))
+            SMB1Session session = state.GetSession(header.UID);
+            if (!share.HasWriteAccess(session.UserName))
             {
                 header.Status = NTStatus.STATUS_ACCESS_DENIED;
                 return new ErrorResponse(CommandName.SMB_COM_SET_INFORMATION2);
@@ -251,15 +251,15 @@ namespace SMBLibrary.Server.SMB1
 
         internal static SMB1Command GetSetInformation2Response(SMB1Header header, SetInformation2Request request, FileSystemShare share, SMB1ConnectionState state)
         {
-            OpenFileObject openFile = state.GetOpenFileObject(request.FID);
+            SMB1Session session = state.GetSession(header.UID);
+            OpenFileObject openFile = session.GetOpenFileObject(request.FID);
             if (openFile == null)
             {
                 header.Status = NTStatus.STATUS_SMB_BAD_FID;
                 return new ErrorResponse(CommandName.SMB_COM_SET_INFORMATION2);
             }
 
-            string userName = state.GetConnectedUserName(header.UID);
-            if (!share.HasWriteAccess(userName))
+            if (!share.HasWriteAccess(session.UserName))
             {
                 header.Status = NTStatus.STATUS_ACCESS_DENIED;
                 return new ErrorResponse(CommandName.SMB_COM_SET_INFORMATION2);

+ 4 - 3
SMBLibrary/Server/SMB1/NTCreateHelper.cs

@@ -18,6 +18,7 @@ namespace SMBLibrary.Server.SMB1
     {
         internal static SMB1Command GetNTCreateResponse(SMB1Header header, NTCreateAndXRequest request, ISMBShare share, SMB1ConnectionState state)
         {
+            SMB1Session session = state.GetSession(header.UID);
             bool isExtended = (request.Flags & NTCreateFlags.NT_CREATE_REQUEST_EXTENDED_RESPONSE) > 0;
             string path = request.FileName;
             if (share is NamedPipeShare)
@@ -25,7 +26,7 @@ namespace SMBLibrary.Server.SMB1
                 Stream pipeStream = ((NamedPipeShare)share).OpenPipe(path);
                 if (pipeStream != null)
                 {
-                    ushort? fileID = state.AddOpenFile(path, pipeStream);
+                    ushort? fileID = session.AddOpenFile(path, pipeStream);
                     if (!fileID.HasValue)
                     {
                         header.Status = NTStatus.STATUS_TOO_MANY_OPENED_FILES;
@@ -47,7 +48,7 @@ namespace SMBLibrary.Server.SMB1
             else // FileSystemShare
             {
                 FileSystemShare fileSystemShare = (FileSystemShare)share;
-                string userName = state.GetConnectedUserName(header.UID);
+                string userName = session.UserName;
                 FileSystemEntry entry;
                 NTStatus createStatus = CreateFile(out entry, fileSystemShare, userName, path, request.CreateDisposition, request.CreateOptions, request.DesiredAccess, state);
                 if (createStatus != NTStatus.STATUS_SUCCESS)
@@ -109,7 +110,7 @@ namespace SMBLibrary.Server.SMB1
                     }
                 }
 
-                ushort? fileID = state.AddOpenFile(path, stream, deleteOnClose);
+                ushort? fileID = session.AddOpenFile(path, stream, deleteOnClose);
                 if (!fileID.HasValue)
                 {
                     header.Status = NTStatus.STATUS_TOO_MANY_OPENED_FILES;

+ 4 - 3
SMBLibrary/Server/SMB1/OpenAndXHelper.cs

@@ -18,6 +18,7 @@ namespace SMBLibrary.Server.SMB1
     {
         internal static SMB1Command GetOpenAndXResponse(SMB1Header header, OpenAndXRequest request, ISMBShare share, SMB1ConnectionState state)
         {
+            SMB1Session session = state.GetSession(header.UID);
             bool isExtended = (request.Flags & OpenFlags.SMB_OPEN_EXTENDED_RESPONSE) > 0;
             string path = request.FileName;
             if (share is NamedPipeShare)
@@ -25,7 +26,7 @@ namespace SMBLibrary.Server.SMB1
                 Stream pipeStream = ((NamedPipeShare)share).OpenPipe(path);
                 if (pipeStream != null)
                 {
-                    ushort? fileID = state.AddOpenFile(path, pipeStream);
+                    ushort? fileID = session.AddOpenFile(path, pipeStream);
                     if (!fileID.HasValue)
                     {
                         header.Status = NTStatus.STATUS_TOO_MANY_OPENED_FILES;
@@ -47,7 +48,7 @@ namespace SMBLibrary.Server.SMB1
             else // FileSystemShare
             {
                 FileSystemShare fileSystemShare = (FileSystemShare)share;
-                string userName = state.GetConnectedUserName(header.UID);
+                string userName = session.UserName;
                 bool hasWriteAccess = fileSystemShare.HasWriteAccess(userName);
                 IFileSystem fileSystem = fileSystemShare.FileSystem;
 
@@ -133,7 +134,7 @@ namespace SMBLibrary.Server.SMB1
                         stream = new PrefetchedStream(stream);
                     }
                 }
-                ushort? fileID = state.AddOpenFile(path, stream);
+                ushort? fileID = session.AddOpenFile(path, stream);
                 if (!fileID.HasValue)
                 {
                     header.Status = NTStatus.STATUS_TOO_MANY_OPENED_FILES;

+ 8 - 4
SMBLibrary/Server/SMB1/ReadWriteResponseHelper.cs

@@ -19,7 +19,8 @@ namespace SMBLibrary.Server.SMB1
     {
         internal static SMB1Command GetReadResponse(SMB1Header header, ReadRequest request, ISMBShare share, SMB1ConnectionState state)
         {
-            OpenFileObject openFile = state.GetOpenFileObject(request.FID);
+            SMB1Session session = state.GetSession(header.UID);
+            OpenFileObject openFile = session.GetOpenFileObject(request.FID);
             if (openFile == null)
             {
                 header.Status = NTStatus.STATUS_INVALID_HANDLE;
@@ -40,7 +41,8 @@ namespace SMBLibrary.Server.SMB1
 
         internal static SMB1Command GetReadResponse(SMB1Header header, ReadAndXRequest request, ISMBShare share, SMB1ConnectionState state)
         {
-            OpenFileObject openFile = state.GetOpenFileObject(request.FID);
+            SMB1Session session = state.GetSession(header.UID);
+            OpenFileObject openFile = session.GetOpenFileObject(request.FID);
             if (openFile == null)
             {
                 header.Status = NTStatus.STATUS_INVALID_HANDLE;
@@ -136,7 +138,8 @@ namespace SMBLibrary.Server.SMB1
 
         internal static SMB1Command GetWriteResponse(SMB1Header header, WriteRequest request, ISMBShare share, SMB1ConnectionState state)
         {
-            OpenFileObject openFile = state.GetOpenFileObject(request.FID);
+            SMB1Session session = state.GetSession(header.UID);
+            OpenFileObject openFile = session.GetOpenFileObject(request.FID);
             if (openFile == null)
             {
                 header.Status = NTStatus.STATUS_INVALID_HANDLE;
@@ -155,7 +158,8 @@ namespace SMBLibrary.Server.SMB1
 
         internal static SMB1Command GetWriteResponse(SMB1Header header, WriteAndXRequest request, ISMBShare share, SMB1ConnectionState state)
         {
-            OpenFileObject openFile = state.GetOpenFileObject(request.FID);
+            SMB1Session session = state.GetSession(header.UID);
+            OpenFileObject openFile = session.GetOpenFileObject(request.FID);
             if (openFile == null)
             {
                 header.Status = NTStatus.STATUS_INVALID_HANDLE;

+ 6 - 3
SMBLibrary/Server/SMB1/ServerResponseHelper.cs

@@ -17,14 +17,16 @@ namespace SMBLibrary.Server.SMB1
     {
         internal static SMB1Command GetCloseResponse(SMB1Header header, CloseRequest request, ISMBShare share, SMB1ConnectionState state)
         {
-            OpenFileObject openFile = state.GetOpenFileObject(request.FID);
+            SMB1Session session = state.GetSession(header.UID);
+            OpenFileObject openFile = session.GetOpenFileObject(request.FID);
             if (openFile == null)
             {
                 header.Status = NTStatus.STATUS_SMB_BAD_FID;
                 return new ErrorResponse(CommandName.SMB_COM_CLOSE);
             }
 
-            state.RemoveOpenFile(request.FID);
+            state.LogToServer(Severity.Verbose, "Close: Closing file '{0}'", openFile.Path);
+            session.RemoveOpenFile(request.FID);
             if (openFile.DeleteOnClose && share is FileSystemShare)
             {
                 try
@@ -42,7 +44,8 @@ namespace SMBLibrary.Server.SMB1
 
         internal static SMB1Command GetFindClose2Request(SMB1Header header, FindClose2Request request, SMB1ConnectionState state)
         {
-            state.ReleaseSearchHandle(request.SearchHandle);
+            SMB1Session session = state.GetSession(header.UID);
+            session.ReleaseSearchHandle(request.SearchHandle);
             return new FindClose2Response();
         }
 

+ 20 - 20
SMBLibrary/Server/SMB1/SessionSetupHelper.cs

@@ -40,25 +40,25 @@ namespace SMBLibrary.Server.SMB1
             if (loginSuccess)
             {
                 state.LogToServer(Severity.Information, "User '{0}' authenticated successfully", message.UserName);
-                ushort? userID = state.AddConnectedUser(message.UserName);
-                if (!userID.HasValue)
+                SMB1Session session = state.CreateSession(message.UserName);
+                if (session == null)
                 {
                     header.Status = NTStatus.STATUS_TOO_MANY_SESSIONS;
                     return new ErrorResponse(CommandName.SMB_COM_SESSION_SETUP_ANDX);
                 }
-                header.UID = userID.Value;
+                header.UID = session.UserID;
                 response.PrimaryDomain = request.PrimaryDomain;
             }
             else if (users.FallbackToGuest(message.UserName))
             {
                 state.LogToServer(Severity.Information, "User '{0}' failed authentication. logged in as guest", message.UserName);
-                ushort? userID = state.AddConnectedUser("Guest");
-                if (!userID.HasValue)
+                SMB1Session session = state.CreateSession("Guest");
+                if (session == null)
                 {
                     header.Status = NTStatus.STATUS_TOO_MANY_SESSIONS;
                     return new ErrorResponse(CommandName.SMB_COM_SESSION_SETUP_ANDX);
                 }
-                header.UID = userID.Value;
+                header.UID = session.UserID;
                 response.Action = SessionSetupAction.SetupGuest;
                 response.PrimaryDomain = request.PrimaryDomain;
             }
@@ -100,6 +100,18 @@ namespace SMBLibrary.Server.SMB1
                 return new ErrorResponse(CommandName.SMB_COM_SESSION_SETUP_ANDX);
             }
 
+            // According to [MS-SMB] 3.3.5.3, a UID MUST be allocated if the server returns STATUS_MORE_PROCESSING_REQUIRED
+            if (header.UID == 0)
+            {
+                ushort? userID = state.AllocateUserID();
+                if (!userID.HasValue)
+                {
+                    header.Status = NTStatus.STATUS_TOO_MANY_SESSIONS;
+                    return new ErrorResponse(request.CommandName);
+                }
+                header.UID = userID.Value;
+            }
+
             MessageTypeName messageType = AuthenticationMessageUtils.GetMessageType(messageBytes);
             if (messageType == MessageTypeName.Negotiate)
             {
@@ -133,24 +145,12 @@ namespace SMBLibrary.Server.SMB1
                 if (loginSuccess)
                 {
                     state.LogToServer(Severity.Information, "User '{0}' authenticated successfully", authenticateMessage.UserName);
-                    ushort? userID = state.AddConnectedUser(authenticateMessage.UserName);
-                    if (!userID.HasValue)
-                    {
-                        header.Status = NTStatus.STATUS_TOO_MANY_SESSIONS;
-                        return new ErrorResponse(CommandName.SMB_COM_SESSION_SETUP_ANDX);
-                    }
-                    header.UID = userID.Value;
+                    state.CreateSession(header.UID, authenticateMessage.UserName);
                 }
                 else if (users.FallbackToGuest(authenticateMessage.UserName))
                 {
                     state.LogToServer(Severity.Information, "User '{0}' failed authentication. logged in as guest", authenticateMessage.UserName);
-                    ushort? userID = state.AddConnectedUser("Guest");
-                    if (!userID.HasValue)
-                    {
-                        header.Status = NTStatus.STATUS_TOO_MANY_SESSIONS;
-                        return new ErrorResponse(CommandName.SMB_COM_SESSION_SETUP_ANDX);
-                    }
-                    header.UID = userID.Value;
+                    state.CreateSession(header.UID, "Guest");
                     response.Action = SessionSetupAction.SetupGuest;
                 }
                 else

+ 14 - 12
SMBLibrary/Server/SMB1/Transaction2SubcommandHelper.cs

@@ -22,6 +22,7 @@ namespace SMBLibrary.Server.SMB1
 
         internal static Transaction2FindFirst2Response GetSubcommandResponse(SMB1Header header, Transaction2FindFirst2Request subcommand, FileSystemShare share, SMB1ConnectionState state)
         {
+            SMB1Session session = state.GetSession(header.UID);
             IFileSystem fileSystem = share.FileSystem;
             string path = subcommand.FileName;
             // '\Directory' - Get the directory info
@@ -122,7 +123,7 @@ namespace SMBLibrary.Server.SMB1
             }
             else
             {
-                ushort? searchHandle = state.AllocateSearchHandle();
+                ushort? searchHandle = session.AllocateSearchHandle();
                 if (!searchHandle.HasValue)
                 {
                     header.Status = NTStatus.STATUS_OS2_NO_MORE_SIDS;
@@ -130,7 +131,7 @@ namespace SMBLibrary.Server.SMB1
                 }
                 response.SID = searchHandle.Value;
                 entries.RemoveRange(0, returnCount);
-                state.OpenSearches.Add(response.SID, entries);
+                session.OpenSearches.Add(response.SID, entries);
             }
             return response;
         }
@@ -195,14 +196,15 @@ namespace SMBLibrary.Server.SMB1
 
         internal static Transaction2FindNext2Response GetSubcommandResponse(SMB1Header header, Transaction2FindNext2Request subcommand, FileSystemShare share, SMB1ConnectionState state)
         {
-            if (!state.OpenSearches.ContainsKey(subcommand.SID))
+            SMB1Session session = state.GetSession(header.UID);
+            if (!session.OpenSearches.ContainsKey(subcommand.SID))
             {
                 header.Status = NTStatus.STATUS_INVALID_HANDLE;
                 return null;
             }
 
             bool returnResumeKeys = (subcommand.Flags & FindFlags.SMB_FIND_RETURN_RESUME_KEYS) > 0;
-            List<FileSystemEntry> entries = state.OpenSearches[subcommand.SID];
+            List<FileSystemEntry> entries = session.OpenSearches[subcommand.SID];
             FindInformationList findInformationList = new FindInformationList();
             for (int index = 0; index < entries.Count; index++)
             {
@@ -218,11 +220,11 @@ namespace SMBLibrary.Server.SMB1
             Transaction2FindNext2Response response = new Transaction2FindNext2Response();
             response.SetFindInformationList(findInformationList, header.UnicodeFlag);
             entries.RemoveRange(0, returnCount);
-            state.OpenSearches[subcommand.SID] = entries;
+            session.OpenSearches[subcommand.SID] = entries;
             response.EndOfSearch = (returnCount == entries.Count) && (entries.Count <= subcommand.SearchCount);
             if (response.EndOfSearch)
             {
-                state.ReleaseSearchHandle(subcommand.SID);
+                session.ReleaseSearchHandle(subcommand.SID);
             }
             return response;
         }
@@ -257,8 +259,9 @@ namespace SMBLibrary.Server.SMB1
 
         internal static Transaction2QueryFileInformationResponse GetSubcommandResponse(SMB1Header header, Transaction2QueryFileInformationRequest subcommand, FileSystemShare share, SMB1ConnectionState state)
         {
+            SMB1Session session = state.GetSession(header.UID);
             IFileSystem fileSystem = share.FileSystem;
-            OpenFileObject openFile = state.GetOpenFileObject(subcommand.FID);
+            OpenFileObject openFile = session.GetOpenFileObject(subcommand.FID);
             if (openFile == null)
             {
                 header.Status = NTStatus.STATUS_INVALID_HANDLE;
@@ -280,7 +283,8 @@ namespace SMBLibrary.Server.SMB1
 
         internal static Transaction2SetFileInformationResponse GetSubcommandResponse(SMB1Header header, Transaction2SetFileInformationRequest subcommand, FileSystemShare share, SMB1ConnectionState state)
         {
-            OpenFileObject openFile = state.GetOpenFileObject(subcommand.FID);
+            SMB1Session session = state.GetSession(header.UID);
+            OpenFileObject openFile = session.GetOpenFileObject(subcommand.FID);
             if (openFile == null)
             {
                 header.Status = NTStatus.STATUS_INVALID_HANDLE;
@@ -300,8 +304,7 @@ namespace SMBLibrary.Server.SMB1
                 }
                 case SetInformationLevel.SMB_SET_FILE_BASIC_INFO:
                 {
-                    string userName = state.GetConnectedUserName(header.UID);
-                    if (!share.HasWriteAccess(userName))
+                    if (!share.HasWriteAccess(session.UserName))
                     {
                         header.Status = NTStatus.STATUS_ACCESS_DENIED;
                         return null;
@@ -353,8 +356,7 @@ namespace SMBLibrary.Server.SMB1
                     if (((SetFileDispositionInfo)subcommand.SetInfo).DeletePending)
                     {
                         // We're supposed to delete the file on close, but it's too late to report errors at this late stage
-                        string userName = state.GetConnectedUserName(header.UID);
-                        if (!share.HasWriteAccess(userName))
+                        if (!share.HasWriteAccess(session.UserName))
                         {
                             header.Status = NTStatus.STATUS_ACCESS_DENIED;
                             return null;

+ 2 - 1
SMBLibrary/Server/SMB1/TransactionSubcommandHelper.cs

@@ -18,7 +18,8 @@ namespace SMBLibrary.Server.SMB1
     {
         internal static TransactionTransactNamedPipeResponse GetSubcommandResponse(SMB1Header header, TransactionTransactNamedPipeRequest subcommand, NamedPipeShare share, SMB1ConnectionState state)
         {
-            OpenFileObject openFile = state.GetOpenFileObject(subcommand.FID);
+            SMB1Session session = state.GetSession(header.UID);
+            OpenFileObject openFile = session.GetOpenFileObject(subcommand.FID);
             if (openFile == null)
             {
                 header.Status = NTStatus.STATUS_INVALID_HANDLE;

+ 6 - 5
SMBLibrary/Server/SMB1/TreeConnectHelper.cs

@@ -16,6 +16,7 @@ namespace SMBLibrary.Server.SMB1
     {
         internal static SMB1Command GetTreeConnectResponse(SMB1Header header, TreeConnectAndXRequest request, SMB1ConnectionState state, NamedPipeShare services, ShareCollection shares)
         {
+            SMB1Session session = state.GetSession(header.UID);
             bool isExtended = (request.Flags & TreeConnectFlags.ExtendedResponse) > 0;
             string shareName = ServerPathUtils.GetShareName(request.Path);
             ISMBShare share;
@@ -35,14 +36,13 @@ namespace SMBLibrary.Server.SMB1
                     return new ErrorResponse(CommandName.SMB_COM_TREE_CONNECT_ANDX);
                 }
 
-                string userName = state.GetConnectedUserName(header.UID);
-                if (!((FileSystemShare)share).HasReadAccess(userName))
+                if (!((FileSystemShare)share).HasReadAccess(session.UserName))
                 {
                     header.Status = NTStatus.STATUS_ACCESS_DENIED;
                     return new ErrorResponse(CommandName.SMB_COM_TREE_CONNECT_ANDX);
                 }
             }
-            ushort? treeID = state.AddConnectedTree(share);
+            ushort? treeID = session.AddConnectedTree(share);
             if (!treeID.HasValue)
             {
                 header.Status = NTStatus.STATUS_INSUFF_SERVER_RESOURCES;
@@ -88,13 +88,14 @@ namespace SMBLibrary.Server.SMB1
 
         internal static SMB1Command GetTreeDisconnectResponse(SMB1Header header, TreeDisconnectRequest request, SMB1ConnectionState state)
         {
-            if (!state.IsTreeConnected(header.TID))
+            SMB1Session session = state.GetSession(header.UID);
+            if (!session.IsTreeConnected(header.TID))
             {
                 header.Status = NTStatus.STATUS_SMB_BAD_TID;
                 return new ErrorResponse(CommandName.SMB_COM_TREE_DISCONNECT);
             }
 
-            state.RemoveConnectedTree(header.TID);
+            session.RemoveConnectedTree(header.TID);
             return new TreeDisconnectResponse();
         }
     }

+ 19 - 12
SMBLibrary/Server/SMBServer.SMB1.cs

@@ -88,8 +88,15 @@ namespace SMBLibrary.Server
             {
                 return ServerResponseHelper.GetEchoResponse((EchoRequest)command, sendQueue);
             }
-            else if (state.IsAuthenticated(header.UID))
+            else
             {
+                SMB1Session session = state.GetSession(header.UID);
+                if (session == null)
+                {
+                    header.Status = NTStatus.STATUS_USER_SESSION_DELETED;
+                    return new ErrorResponse(command.CommandName);
+                }
+
                 if (command is TreeConnectAndXRequest)
                 {
                     TreeConnectAndXRequest request = (TreeConnectAndXRequest)command;
@@ -97,13 +104,18 @@ namespace SMBLibrary.Server
                 }
                 else if (command is LogoffAndXRequest)
                 {
-                    // FIXME: Remove connected trees that the user has connected to
-                    state.RemoveConnectedUser(header.UID);
+                    state.RemoveSession(header.UID);
                     return new LogoffAndXResponse();
                 }
-                else if (state.IsTreeConnected(header.TID))
+                else
                 {
-                    ISMBShare share = state.GetConnectedTree(header.TID);
+                    ISMBShare share = session.GetConnectedTree(header.TID);
+                    if (share == null)
+                    {
+                        header.Status = NTStatus.STATUS_SMB_BAD_TID;
+                        return new ErrorResponse(command.CommandName);
+                    }
+
                     if (command is CreateDirectoryRequest)
                     {
                         if (!(share is FileSystemShare))
@@ -180,7 +192,7 @@ namespace SMBLibrary.Server
                     }
                     else if (command is WriteRequest)
                     {
-                        string userName = state.GetConnectedUserName(header.UID);
+                        string userName = session.UserName;
                         if (share is FileSystemShare && !((FileSystemShare)share).HasWriteAccess(userName))
                         {
                             header.Status = NTStatus.STATUS_ACCESS_DENIED;
@@ -234,7 +246,7 @@ namespace SMBLibrary.Server
                     }
                     else if (command is WriteAndXRequest)
                     {
-                        string userName = state.GetConnectedUserName(header.UID);
+                        string userName = session.UserName;
                         if (share is FileSystemShare && !((FileSystemShare)share).HasWriteAccess(userName))
                         {
                             header.Status = NTStatus.STATUS_ACCESS_DENIED;
@@ -294,11 +306,6 @@ namespace SMBLibrary.Server
                         return NTCreateHelper.GetNTCreateResponse(header, request, share, state);
                     }
                 }
-                else
-                {
-                    header.Status = NTStatus.STATUS_SMB_BAD_TID;
-                    return new ErrorResponse(command.CommandName);
-                }
             }
 
             header.Status = NTStatus.STATUS_SMB_BAD_COMMAND;