From e62744bb69bf795e15b9f667b7b4a46b11744ae7 Mon Sep 17 00:00:00 2001
From: Jordan Borean <jborean93@gmail.com>
Date: Sat, 8 May 2021 08:00:23 +1000
Subject: [PATCH] Fix up connection cache for scandir and rmtree (#101)

---
 CHANGELOG.md               |  2 ++
 smbclient/_os.py           | 15 ++++++----
 smbclient/shutil.py        |  8 ++---
 tests/test_smbclient_os.py | 60 ++++++++++++++++++++++++++++++++++++++
 4 files changed, 75 insertions(+), 10 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 9384ae8..02dc65b 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -5,6 +5,8 @@
 * Unified DFS path handling when using any API that uses a transaction to open the file
   * This includes `smbclient.rename` and `smbclient.replace`
 * Fixed up `smbclient.rename` to work with directories
+* `smbclient.scandir` will continue to use the connection cache when getting stat information of a dir entry
+* `smbclient.shutil.rmtree` will continue to use the connection cache when removing child entries
 
 
 ## 1.5.0 - 2021-03-25
diff --git a/smbclient/_os.py b/smbclient/_os.py
index 4877f1d..0d52a81 100644
--- a/smbclient/_os.py
+++ b/smbclient/_os.py
@@ -526,13 +526,15 @@ def scandir(path, search_pattern="*", **kwargs):
     :param kwargs: Common SMB Session arguments for smbclient.
     :return: An iterator of DirEntry objects in the directory.
     """
+    connection_cache = kwargs.get('connection_cache', None)
     with SMBDirectoryIO(path, share_access='rwd', **kwargs) as fd:
         for dir_info in fd.query_directory(search_pattern, FileInformationClass.FILE_ID_FULL_DIRECTORY_INFORMATION):
             filename = dir_info['file_name'].get_value().decode('utf-16-le')
             if filename in [u'.', u'..']:
                 continue
 
-            dir_entry = SMBDirEntry(SMBRawIO(u"%s\\%s" % (path, filename), **kwargs), dir_info)
+            dir_entry = SMBDirEntry(SMBRawIO(u"%s\\%s" % (path, filename), **kwargs), dir_info,
+                                    connection_cache=connection_cache)
             yield dir_entry
 
 
@@ -1093,11 +1095,12 @@ def _set_basic_information(path, creation_time=0, last_access_time=0, last_write
 
 class SMBDirEntry(object):
 
-    def __init__(self, raw, dir_info):
+    def __init__(self, raw, dir_info, connection_cache=None):
         self._smb_raw = raw
         self._dir_info = dir_info
         self._stat = None
         self._lstat = None
+        self._connection_cache = connection_cache
 
     def __str__(self):
         return '<{0}: {1!r}>'.format(self.__class__.__name__, to_native(self.name))
@@ -1208,16 +1211,16 @@ class SMBDirEntry(object):
         if follow_symlinks:
             if not self._stat:
                 if self.is_symlink():
-                    self._stat = stat(self.path)
+                    self._stat = stat(self.path, connection_cache=self._connection_cache)
                 else:
                     # Because it's not a symlink lstat will be the same as stat so set both.
                     if self._lstat is None:
-                        self._lstat = lstat(self._smb_raw.name)
+                        self._lstat = lstat(self._smb_raw.name, connection_cache=self._connection_cache)
                     self._stat = self._lstat
             return self._stat
         else:
             if not self._lstat:
-                self._lstat = lstat(self.path)
+                self._lstat = lstat(self.path, connection_cache=self._connection_cache)
             return self._lstat
 
     @classmethod
@@ -1229,7 +1232,7 @@ class SMBDirEntry(object):
         dir_info['file_attributes'] = file_stat.st_file_attributes
         dir_info['file_id'] = file_stat.st_ino
 
-        dir_entry = cls(SMBRawIO(path, **kwargs), dir_info)
+        dir_entry = cls(SMBRawIO(path, **kwargs), dir_info, connection_cache=kwargs.get('connection_cache', None))
         dir_entry._stat = file_stat
         return dir_entry
 
diff --git a/smbclient/shutil.py b/smbclient/shutil.py
index 51e09b9..52f0d53 100644
--- a/smbclient/shutil.py
+++ b/smbclient/shutil.py
@@ -408,19 +408,19 @@ def rmtree(path, ignore_errors=False, onerror=None, **kwargs):
         if dir_entry.is_symlink() and \
                 dir_entry.stat(follow_symlinks=False).st_file_attributes & FileAttributes.FILE_ATTRIBUTE_DIRECTORY:
             try:
-                rmdir(dir_entry.path)
+                rmdir(dir_entry.path, **kwargs)
             except OSError:
                 onerror(rmdir, dir_entry.path, sys.exc_info())
         elif dir_entry.is_dir():
-            rmtree(dir_entry.path, ignore_errors, onerror)
+            rmtree(dir_entry.path, ignore_errors, onerror, **kwargs)
         else:
             try:
-                remove(dir_entry.path)
+                remove(dir_entry.path, **kwargs)
             except OSError:
                 onerror(remove, dir_entry.path, sys.exc_info())
 
     try:
-        rmdir(path)
+        rmdir(path, **kwargs)
     except OSError:
         onerror(rmdir, path, sys.exc_info())
 
diff --git a/tests/test_smbclient_os.py b/tests/test_smbclient_os.py
index 5c43fbd..a6427a1 100644
--- a/tests/test_smbclient_os.py
+++ b/tests/test_smbclient_os.py
@@ -13,6 +13,7 @@ import re
 import six
 import smbclient  # Tests that we expose this in smbclient/__init__.py
 import stat
+import time
 
 from smbclient._io import (
     query_info,
@@ -25,6 +26,10 @@ from smbclient._os import (
     SMBFileIO,
 )
 
+from smbclient.shutil import (
+    rmtree,
+)
+
 from smbprotocol.exceptions import (
     SMBAuthenticationError,
     SMBOSError,
@@ -1320,6 +1325,61 @@ def test_scandir_with_broken_symlink(smb_share):
         assert entry.is_file(follow_symlinks=False) is False  # broken link target
 
 
+def test_scandir_with_cache(smb_real):
+    share_path = u"%s\\%s" % (smb_real[4], u"Pýtæs†-[%s] 💩" % time.time())
+    cache = {}
+    smbclient.mkdir(share_path, username=smb_real[0], password=smb_real[1], port=smb_real[3], connection_cache=cache)
+
+    try:
+
+        dir_path = ntpath.join(share_path, 'directory')
+        smbclient.makedirs(dir_path, exist_ok=True, connection_cache=cache)
+
+        for name in ['file.txt', u'unicode †[💩].txt']:
+            with smbclient.open_file(ntpath.join(dir_path, name), mode='w', connection_cache=cache) as fd:
+                fd.write(u"content")
+
+        for name in ['subdir1', 'subdir2', u'unicode dir †[💩]', 'subdir1\\sub']:
+            smbclient.mkdir(ntpath.join(dir_path, name), connection_cache=cache)
+
+        count = 0
+        names = []
+        for dir_entry in smbclient.scandir(dir_path, connection_cache=cache):
+            assert isinstance(dir_entry, SMBDirEntry)
+            names.append(dir_entry.name)
+
+            # Test out dir_entry for specific file and dir examples
+            if dir_entry.name == 'subdir1':
+                assert str(dir_entry) == "<SMBDirEntry: 'subdir1'>"
+                assert dir_entry.is_dir() is True
+                assert dir_entry.is_file() is False
+                assert dir_entry.stat(follow_symlinks=False).st_ino == dir_entry.inode()
+                assert dir_entry.stat().st_ino == dir_entry.inode()
+            elif dir_entry.name == 'file.txt':
+                assert str(dir_entry) == "<SMBDirEntry: 'file.txt'>"
+                assert dir_entry.is_dir() is False
+                assert dir_entry.is_file() is True
+                assert dir_entry.stat().st_ino == dir_entry.inode()
+                assert dir_entry.stat(follow_symlinks=False).st_ino == dir_entry.inode()
+
+            assert dir_entry.is_symlink() is False
+            assert dir_entry.inode() is not None
+            assert dir_entry.inode() == dir_entry.stat().st_ino
+
+            count += 1
+
+        assert count == 5
+        assert u'unicode †[💩].txt' in names
+        assert u'unicode dir †[💩]' in names
+        assert u'subdir2' in names
+        assert u'subdir1' in names
+        assert u'file.txt' in names
+
+    finally:
+        rmtree(share_path, connection_cache=cache)
+        smbclient.reset_connection_cache(connection_cache=cache)
+
+
 def test_stat_directory(smb_share):
     actual = smbclient.stat(smb_share)
     assert isinstance(actual, smbclient.SMBStatResult)
-- 
GitLab