Browse Source

All matching opened files will now be closed during tree disconnect

Tal Aloni 8 years ago
parent
commit
cff82dd4b8

+ 17 - 2
SMBLibrary/Server/ConnectionState/SMB1Session.cs

@@ -57,9 +57,24 @@ namespace SMBLibrary.Server
             return share;
         }
 
-        public void RemoveConnectedTree(ushort treeID)
+        public void DisconnectTree(ushort treeID)
         {
-            m_connectedTrees.Remove(treeID);
+            ISMBShare share;
+            m_connectedTrees.TryGetValue(treeID, out share);
+            if (share != null)
+            {
+                List<ushort> fileIDList = new List<ushort>(m_openFiles.Keys);
+                foreach (ushort fileID in fileIDList)
+                {
+                    OpenFileObject openFile = m_openFiles[fileID];
+                    if (openFile.TreeID == treeID)
+                    {
+                        share.FileStore.CloseFile(openFile.Handle);
+                        m_openFiles.Remove(fileID);
+                    }
+                }
+                m_connectedTrees.Remove(treeID);
+            }
         }
 
         public bool IsTreeConnected(ushort treeID)

+ 17 - 7
SMBLibrary/Server/ConnectionState/SMB2Session.cs

@@ -79,14 +79,24 @@ namespace SMBLibrary.Server
             }
         }
 
-        public void RemoveConnectedTree(uint treeID)
+        public void DisconnectTree(uint treeID)
         {
-            m_connectedTrees.Remove(treeID);
-        }
-
-        public void RemoveConnectedTrees()
-        {
-            m_connectedTrees.Clear();
+            ISMBShare share;
+            m_connectedTrees.TryGetValue(treeID, out share);
+            if (share != null)
+            {
+                List<ulong> fileIDList = new List<ulong>(m_openFiles.Keys);
+                foreach (ushort fileID in fileIDList)
+                {
+                    OpenFileObject openFile = m_openFiles[fileID];
+                    if (openFile.TreeID == treeID)
+                    {
+                        share.FileStore.CloseFile(openFile.Handle);
+                        m_openFiles.Remove(fileID);
+                    }
+                }
+                m_connectedTrees.Remove(treeID);
+            }
         }
 
         public bool IsTreeConnected(uint treeID)

+ 1 - 1
SMBLibrary/Server/SMB1/TreeConnectHelper.cs

@@ -98,7 +98,7 @@ namespace SMBLibrary.Server.SMB1
                 return new ErrorResponse(request.CommandName);
             }
 
-            session.RemoveConnectedTree(header.TID);
+            session.DisconnectTree(header.TID);
             return new TreeDisconnectResponse();
         }
     }

+ 1 - 1
SMBLibrary/Server/SMBServer.SMB2.cs

@@ -158,7 +158,7 @@ namespace SMBLibrary.Server
 
                     if (command is TreeDisconnectRequest)
                     {
-                        session.RemoveConnectedTree(command.Header.TreeID);
+                        session.DisconnectTree(command.Header.TreeID);
                         return new TreeDisconnectResponse();
                     }
                     else if (command is CreateRequest)