From e62744bb69bf795e15b9f667b7b4a46b11744ae7 Mon Sep 17 00:00:00 2001 From: Jordan Borean <jborean93@gmail.com> Date: Sat, 8 May 2021 08:00:23 +1000 Subject: [PATCH] Fix up connection cache for scandir and rmtree (#101) --- CHANGELOG.md | 2 ++ smbclient/_os.py | 15 ++++++---- smbclient/shutil.py | 8 ++--- tests/test_smbclient_os.py | 60 ++++++++++++++++++++++++++++++++++++++ 4 files changed, 75 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9384ae8..02dc65b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,8 @@ * Unified DFS path handling when using any API that uses a transaction to open the file * This includes `smbclient.rename` and `smbclient.replace` * Fixed up `smbclient.rename` to work with directories +* `smbclient.scandir` will continue to use the connection cache when getting stat information of a dir entry +* `smbclient.shutil.rmtree` will continue to use the connection cache when removing child entries ## 1.5.0 - 2021-03-25 diff --git a/smbclient/_os.py b/smbclient/_os.py index 4877f1d..0d52a81 100644 --- a/smbclient/_os.py +++ b/smbclient/_os.py @@ -526,13 +526,15 @@ def scandir(path, search_pattern="*", **kwargs): :param kwargs: Common SMB Session arguments for smbclient. :return: An iterator of DirEntry objects in the directory. """ + connection_cache = kwargs.get('connection_cache', None) with SMBDirectoryIO(path, share_access='rwd', **kwargs) as fd: for dir_info in fd.query_directory(search_pattern, FileInformationClass.FILE_ID_FULL_DIRECTORY_INFORMATION): filename = dir_info['file_name'].get_value().decode('utf-16-le') if filename in [u'.', u'..']: continue - dir_entry = SMBDirEntry(SMBRawIO(u"%s\\%s" % (path, filename), **kwargs), dir_info) + dir_entry = SMBDirEntry(SMBRawIO(u"%s\\%s" % (path, filename), **kwargs), dir_info, + connection_cache=connection_cache) yield dir_entry @@ -1093,11 +1095,12 @@ def _set_basic_information(path, creation_time=0, last_access_time=0, last_write class SMBDirEntry(object): - def __init__(self, raw, dir_info): + def __init__(self, raw, dir_info, connection_cache=None): self._smb_raw = raw self._dir_info = dir_info self._stat = None self._lstat = None + self._connection_cache = connection_cache def __str__(self): return '<{0}: {1!r}>'.format(self.__class__.__name__, to_native(self.name)) @@ -1208,16 +1211,16 @@ class SMBDirEntry(object): if follow_symlinks: if not self._stat: if self.is_symlink(): - self._stat = stat(self.path) + self._stat = stat(self.path, connection_cache=self._connection_cache) else: # Because it's not a symlink lstat will be the same as stat so set both. if self._lstat is None: - self._lstat = lstat(self._smb_raw.name) + self._lstat = lstat(self._smb_raw.name, connection_cache=self._connection_cache) self._stat = self._lstat return self._stat else: if not self._lstat: - self._lstat = lstat(self.path) + self._lstat = lstat(self.path, connection_cache=self._connection_cache) return self._lstat @classmethod @@ -1229,7 +1232,7 @@ class SMBDirEntry(object): dir_info['file_attributes'] = file_stat.st_file_attributes dir_info['file_id'] = file_stat.st_ino - dir_entry = cls(SMBRawIO(path, **kwargs), dir_info) + dir_entry = cls(SMBRawIO(path, **kwargs), dir_info, connection_cache=kwargs.get('connection_cache', None)) dir_entry._stat = file_stat return dir_entry diff --git a/smbclient/shutil.py b/smbclient/shutil.py index 51e09b9..52f0d53 100644 --- a/smbclient/shutil.py +++ b/smbclient/shutil.py @@ -408,19 +408,19 @@ def rmtree(path, ignore_errors=False, onerror=None, **kwargs): if dir_entry.is_symlink() and \ dir_entry.stat(follow_symlinks=False).st_file_attributes & FileAttributes.FILE_ATTRIBUTE_DIRECTORY: try: - rmdir(dir_entry.path) + rmdir(dir_entry.path, **kwargs) except OSError: onerror(rmdir, dir_entry.path, sys.exc_info()) elif dir_entry.is_dir(): - rmtree(dir_entry.path, ignore_errors, onerror) + rmtree(dir_entry.path, ignore_errors, onerror, **kwargs) else: try: - remove(dir_entry.path) + remove(dir_entry.path, **kwargs) except OSError: onerror(remove, dir_entry.path, sys.exc_info()) try: - rmdir(path) + rmdir(path, **kwargs) except OSError: onerror(rmdir, path, sys.exc_info()) diff --git a/tests/test_smbclient_os.py b/tests/test_smbclient_os.py index 5c43fbd..a6427a1 100644 --- a/tests/test_smbclient_os.py +++ b/tests/test_smbclient_os.py @@ -13,6 +13,7 @@ import re import six import smbclient # Tests that we expose this in smbclient/__init__.py import stat +import time from smbclient._io import ( query_info, @@ -25,6 +26,10 @@ from smbclient._os import ( SMBFileIO, ) +from smbclient.shutil import ( + rmtree, +) + from smbprotocol.exceptions import ( SMBAuthenticationError, SMBOSError, @@ -1320,6 +1325,61 @@ def test_scandir_with_broken_symlink(smb_share): assert entry.is_file(follow_symlinks=False) is False # broken link target +def test_scandir_with_cache(smb_real): + share_path = u"%s\\%s" % (smb_real[4], u"Pýtæs†-[%s] 💩" % time.time()) + cache = {} + smbclient.mkdir(share_path, username=smb_real[0], password=smb_real[1], port=smb_real[3], connection_cache=cache) + + try: + + dir_path = ntpath.join(share_path, 'directory') + smbclient.makedirs(dir_path, exist_ok=True, connection_cache=cache) + + for name in ['file.txt', u'unicode †[💩].txt']: + with smbclient.open_file(ntpath.join(dir_path, name), mode='w', connection_cache=cache) as fd: + fd.write(u"content") + + for name in ['subdir1', 'subdir2', u'unicode dir †[💩]', 'subdir1\\sub']: + smbclient.mkdir(ntpath.join(dir_path, name), connection_cache=cache) + + count = 0 + names = [] + for dir_entry in smbclient.scandir(dir_path, connection_cache=cache): + assert isinstance(dir_entry, SMBDirEntry) + names.append(dir_entry.name) + + # Test out dir_entry for specific file and dir examples + if dir_entry.name == 'subdir1': + assert str(dir_entry) == "<SMBDirEntry: 'subdir1'>" + assert dir_entry.is_dir() is True + assert dir_entry.is_file() is False + assert dir_entry.stat(follow_symlinks=False).st_ino == dir_entry.inode() + assert dir_entry.stat().st_ino == dir_entry.inode() + elif dir_entry.name == 'file.txt': + assert str(dir_entry) == "<SMBDirEntry: 'file.txt'>" + assert dir_entry.is_dir() is False + assert dir_entry.is_file() is True + assert dir_entry.stat().st_ino == dir_entry.inode() + assert dir_entry.stat(follow_symlinks=False).st_ino == dir_entry.inode() + + assert dir_entry.is_symlink() is False + assert dir_entry.inode() is not None + assert dir_entry.inode() == dir_entry.stat().st_ino + + count += 1 + + assert count == 5 + assert u'unicode †[💩].txt' in names + assert u'unicode dir †[💩]' in names + assert u'subdir2' in names + assert u'subdir1' in names + assert u'file.txt' in names + + finally: + rmtree(share_path, connection_cache=cache) + smbclient.reset_connection_cache(connection_cache=cache) + + def test_stat_directory(smb_share): actual = smbclient.stat(smb_share) assert isinstance(actual, smbclient.SMBStatResult) -- GitLab