From 6840a52028e00df1a4b03f948d28bddf17db4377 Mon Sep 17 00:00:00 2001
From: Jordan Borean <jborean93@gmail.com>
Date: Thu, 21 Oct 2021 05:35:17 +1000
Subject: [PATCH] Add newer encryption types and signing algorithms (#147)

This adds the new encryption types AES 256 GCM/CCM and support for
negotiating the signing algorithm, including AES GCM supported on newer
SMB 3.1.1 dialects. These algorithms were introduced in Server
2022/Windows 11 and Samba 4.15.x and can improve the speed of
encryption/signatures dependending on the CPUs involved.

The encryption and signing algorithms can now also be specified by the
caller when calling connection.connect() in order to influence what is
offered by the client and its priority order.

The negotiate NETNAME_CONTEXT_ID negotiate context is also added to the
negotiate request as per the updated guidelines for SMB. This new
context doesn't influence the negotiation process but is used for
network inspection tools to determine the target of the request.
---
 .github/workflows/ci.yml     |  13 +-
 CHANGELOG.md                 |   6 +-
 build_helpers/lib.sh         |   2 +-
 build_helpers/samba-setup.sh |   2 +-
 setup.py                     |   2 +-
 smbprotocol/connection.py    | 279 ++++++++++++++++++++++++++---------
 smbprotocol/session.py       |  30 +++-
 tests/test_connection.py     | 162 +++++++++++++-------
 tests/test_smbclient_os.py   |   2 -
 tests/test_tree.py           |  49 ++++++
 10 files changed, 410 insertions(+), 137 deletions(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 4fea2e0..4415ee2 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -34,7 +34,8 @@ jobs:
         os:
         - ubuntu-latest
         - macOS-latest
-        - windows-latest
+        - windows-2019
+        - windows-2022
         python-version:
         - 3.6
         - 3.7
@@ -50,6 +51,16 @@ jobs:
           python-arch: x86
         - os: macOS-latest
           python-arch: x86
+        - os: windows-2019
+          python-arch: x86
+        - os: windows-2019
+          python-version: 3.6
+        - os: windows-2019
+          python-version: 3.7
+        - os: windows-2019
+          python-version: 3.8
+        - os: windows-2019
+          python-version: 3.9
 
     steps:
     - uses: actions/checkout@v2
diff --git a/CHANGELOG.md b/CHANGELOG.md
index ea6d7b8..1191481 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,7 +1,11 @@
 # Changelog
 
-## 1.7.1 - TBD
+## 1.8.0 - TBD
 
+* Added support for 256bit keyed encryption ciphers
+* Added support for signing with AES GCM
+* Now sends the `SMB2_NETNAME_NEGOTIATE_CONTEXT_ID` with the negotiate request
+* Adds the Python requirement of [pykrb5](https://github.com/jborean93/pykrb5) for Kerberos support on non-Windows
 * Fix unpacking security descriptor ACEs with extra data on the end - https://github.com/jborean93/smbprotocol/pull/143
 * Set `index_number` in `FileInternalInformation` to be an unsigned integer to match the other structures
 * Clear out expired DFS referrals to avoid memory leaks and stale DFS information - https://github.com/jborean93/smbprotocol/issues/136
diff --git a/build_helpers/lib.sh b/build_helpers/lib.sh
index c9b412f..182d7c5 100755
--- a/build_helpers/lib.sh
+++ b/build_helpers/lib.sh
@@ -35,7 +35,7 @@ lib::setup::smb_server() {
             --detach \
             --rm \
             --publish ${SMB_PORT}:445 \
-            --volume $( pwd )/build_helpers:/app \
+            --volume $( pwd )/build_helpers:/app:z \
             --workdir /app \
             archlinux:latest \
             /bin/bash \
diff --git a/build_helpers/samba-setup.sh b/build_helpers/samba-setup.sh
index 7883be9..619a6c0 100755
--- a/build_helpers/samba-setup.sh
+++ b/build_helpers/samba-setup.sh
@@ -16,7 +16,7 @@ valid users = @smbgroup
 server signing = mandatory
 ea support = yes
 store dos attributes = yes
-vfs objects = xattr_tdb streams_xattr
+vfs objects = streams_xattr xattr_tdb
 log level = 0
 
 [dfs]
diff --git a/setup.py b/setup.py
index d76adbb..53a3626 100644
--- a/setup.py
+++ b/setup.py
@@ -18,7 +18,7 @@ with open(abs_path('README.md'), mode='rb') as fd:
 
 setup(
     name='smbprotocol',
-    version='1.7.1',
+    version='1.8.0',
     packages=['smbclient', 'smbprotocol'],
     install_requires=[
         'cryptography>=2.0',
diff --git a/smbprotocol/connection.py b/smbprotocol/connection.py
index 5146a37..680d45f 100644
--- a/smbprotocol/connection.py
+++ b/smbprotocol/connection.py
@@ -6,7 +6,6 @@ import binascii
 import hashlib
 import hmac
 import logging
-import math
 import os
 import struct
 import time
@@ -16,10 +15,6 @@ from collections import (
     OrderedDict,
 )
 
-from cryptography.exceptions import (
-    UnsupportedAlgorithm,
-)
-
 from cryptography.hazmat.backends import (
     default_backend,
 )
@@ -77,6 +72,7 @@ from smbprotocol.structure import (
     FlagField,
     IntField,
     ListField,
+    TextField,
     Structure,
     StructureField,
     UuidField,
@@ -130,6 +126,11 @@ class NegotiateContextType(object):
     """
     SMB2_PREAUTH_INTEGRITY_CAPABILITIES = 0x0001
     SMB2_ENCRYPTION_CAPABILITIES = 0x0002
+    SMB2_COMPRESSION_CAPABILITIES = 0x0003
+    SMB2_NETNAME_NEGOTIATE_CONTEXT_ID = 0x0005
+    SMB2_TRANSPORT_CAPABILITIES = 0x0006
+    SMB2_RDMA_TRANSFORM_CAPABILITIES = 0x0007
+    SMB2_SIGNING_CAPABILITIES = 0x0008
 
 
 class HashAlgorithms(object):
@@ -141,12 +142,6 @@ class HashAlgorithms(object):
     """
     SHA_512 = 0x0001
 
-    @staticmethod
-    def get_algorithm(hash):
-        return {
-            HashAlgorithms.SHA_512: hashlib.sha512
-        }[hash]
-
 
 class Ciphers(object):
     """
@@ -157,28 +152,20 @@ class Ciphers(object):
     """
     AES_128_CCM = 0x0001
     AES_128_GCM = 0x0002
+    AES_256_CCM = 0x0003
+    AES_256_GCM = 0x0004
 
-    @staticmethod
-    def get_cipher(cipher):
-        return {
-            Ciphers.AES_128_CCM: aead.AESCCM,
-            Ciphers.AES_128_GCM: aead.AESGCM
-        }[cipher]
 
-    @staticmethod
-    def get_supported_ciphers():
-        supported_ciphers = []
-        try:
-            aead.AESGCM(b"\x00" * 16)
-            supported_ciphers.append(Ciphers.AES_128_GCM)
-        except UnsupportedAlgorithm:  # pragma: no cover
-            pass
-        try:
-            aead.AESCCM(b"\x00" * 16)
-            supported_ciphers.append(Ciphers.AES_128_CCM)
-        except UnsupportedAlgorithm:  # pragma: no cover
-            pass
-        return supported_ciphers
+class SigningAlgorithms:
+    """
+    [MS-SMB2] 2.2.3.1.7 SMB2_SIGNING_CAPABILITIES
+
+    https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-smb2/cb9b5d66-b6be-4d18-aa66-8784a871cc10
+    16-bit integer IDs that specify the supported signing algorithms.
+    """
+    HMAC_SHA256 = 0x0000
+    AES_CMAC = 0x0001
+    AES_GMAC = 0x0002
 
 
 class SMB2NegotiateRequest(Structure):
@@ -301,16 +288,20 @@ class SMB3NegotiateRequest(Structure):
         context_count = structure['negotiate_context_count'].get_value()
         context_list = []
         for idx in range(0, context_count):
-            field, data = self._parse_negotiate_context_entry(data, idx)
+            field, data = self._parse_negotiate_context_entry(data)
             context_list.append(field)
 
         return context_list
 
-    def _parse_negotiate_context_entry(self, data, idx):
+    def _parse_negotiate_context_entry(self, data):
         data_length = struct.unpack("<H", data[2:4])[0]
         negotiate_context = SMB2NegotiateContextRequest()
         negotiate_context.unpack(data[:data_length + 8])
-        return negotiate_context, data[8 + data_length:]
+        padded_size = data_length % 8
+        if padded_size != 0:
+            padded_size = 8 - padded_size
+
+        return negotiate_context, data[8 + data_length + padded_size:]
 
 
 class SMB2NegotiateContextRequest(Structure):
@@ -349,11 +340,14 @@ class SMB2NegotiateContextRequest(Structure):
 
     def _data_structure_type(self, structure):
         con_type = structure['context_type'].get_value()
-        if con_type == \
-                NegotiateContextType.SMB2_PREAUTH_INTEGRITY_CAPABILITIES:
+        if con_type == NegotiateContextType.SMB2_PREAUTH_INTEGRITY_CAPABILITIES:
             return SMB2PreauthIntegrityCapabilities
         elif con_type == NegotiateContextType.SMB2_ENCRYPTION_CAPABILITIES:
             return SMB2EncryptionCapabilities
+        elif con_type == NegotiateContextType.SMB2_NETNAME_NEGOTIATE_CONTEXT_ID:
+            return SMB2NetnameNegotiateContextId
+        elif con_type == NegotiateContextType.SMB2_SIGNING_CAPABILITIES:
+            return SMB2SigningCapabilities
 
     def _padding_size(self, structure):
         data_size = len(structure['data'])
@@ -418,6 +412,48 @@ class SMB2EncryptionCapabilities(Structure):
         super(SMB2EncryptionCapabilities, self).__init__()
 
 
+class SMB2NetnameNegotiateContextId(Structure):
+    """
+    [MS-SMB2] 2.2.3.1.4 SMB2_NETNAME_NEGOTIATE_CONTEXT_ID
+
+    https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-smb2/ca6726bd-b9cf-43d9-b0bc-d127d3c993b3
+
+    The SMB2_NETNAME_NEGOTIATE_CONTEXT_ID context is specified in an SMB2
+    NEGOTIATE request to indicate the server name the client connects to.
+    """
+
+    def __init__(self):
+        self.fields = OrderedDict([
+            ('net_name', TextField()),
+        ])
+        super().__init__()
+
+
+class SMB2SigningCapabilities(Structure):
+    """
+    [MS-SMB2] 2.2.3.1.7 SMB2_SIGNING_CAPABILITIES
+
+    https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-smb2/cb9b5d66-b6be-4d18-aa66-8784a871cc10
+
+    The SMB2_SIGNING_CAPABILITIES context is specified in an SMB2 NEGOTIATE
+    request by the client to indicate which signing algorithms the client supports.
+    """
+
+    def __init__(self):
+        self.fields = OrderedDict([
+            ('signing_algorithm_count', IntField(
+                size=2,
+                default=lambda s: len(s['signing_algorithms'].get_value()),
+            )),
+            ('signing_algorithms', ListField(
+                size=lambda s: s['signing_algorithm_count'].get_value() * 2,
+                list_count=lambda s: s['signing_algorithm_count'].get_value(),
+                list_type=EnumField(size=2, enum_type=SigningAlgorithms),
+            )),
+        ])
+        super().__init__()
+
+
 class SMB2NegotiateResponse(Structure):
     """
     [MS-SMB2] v53.0 2017-09-15
@@ -696,10 +732,13 @@ class Connection(object):
         # The cipher object that was negotiated
         self.cipher_id = None
 
+        # The signing algorithm that was negotiated
+        self.signing_algorithm_id = None
+
         # Keep track of the message processing thread's potential traceback that it may raise.
         self._t_exc = None
 
-    def connect(self, dialect=None, timeout=60):
+    def connect(self, dialect=None, timeout=60, preferred_encryption_algos=None, preferred_signing_algos=None):
         """
         Will connect to the target server and negotiate the capabilities
         with the client. Once setup, the client MUST call the disconnect()
@@ -707,11 +746,35 @@ class Connection(object):
         various connection properties that denote the capabilities of the
         server.
 
+        If no preferred encryption or signing algorithms are specified then
+        all algorithms are offered during negotiation. Older dialects may not
+        be offered if a custom encryption or signing algorithm list is
+        specified without the algorithm required by that dialect.
+
+        By default the following encryption algorithms are used:
+
+            AES_128_GCM
+            AES_128_CCM (required for SMB 3.0.x)
+            AES_256_GCM
+            AES_256_CCM
+
+        By default the following signing algorithms are used:
+
+            AES_GMAC
+            AES_CMAC (required for SMB 3.0.x)
+            HMAC_SHA256 (required for SMB 2.x)
+
         :param dialect: If specified, forces the dialect that is negotiated
             with the server, if not set, then the newest dialect supported by
             the server is used up to SMB 3.1.1
         :param timeout: The timeout in seconds to wait for the initial
             negotiation process to complete
+        :param preferred_encryption_algos: A list of encryption algorithm ids
+            in priority order from highest to lowest. See :class:`Ciphers` for
+            a list of known identifiers.
+        :param preferred_signing_algos: A list of signing algorithm ids in
+            priority order from highest to lowest.
+            See :class:`SigningAlgorithms` for a list of known identifiers.
         """
         log.info("Setting up transport connection")
         self.transport = Tcp(self.server_name, self.port, timeout)
@@ -722,7 +785,18 @@ class Connection(object):
         t_worker.start()
 
         log.info("Starting negotiation with SMB server")
-        smb_response = self._send_smb2_negotiate(dialect, timeout)
+        enc_algos = preferred_encryption_algos or [
+            Ciphers.AES_128_GCM,
+            Ciphers.AES_128_CCM,
+            Ciphers.AES_256_GCM,
+            Ciphers.AES_256_CCM,
+        ]
+        sign_algos = preferred_signing_algos or [
+            SigningAlgorithms.AES_GMAC,
+            SigningAlgorithms.AES_CMAC,
+            SigningAlgorithms.HMAC_SHA256,
+        ]
+        smb_response = self._send_smb2_negotiate(dialect, timeout, enc_algos, sign_algos)
         log.info("Negotiated dialect: %s"
                  % str(smb_response['dialect_revision']))
         self.dialect = smb_response['dialect_revision'].get_value()
@@ -767,15 +841,17 @@ class Connection(object):
         # SMB 3.1
         if self.dialect >= Dialects.SMB_3_1_1:
             for context in smb_response['negotiate_context_list']:
-                if context['context_type'].get_value() == \
-                        NegotiateContextType.SMB2_ENCRYPTION_CAPABILITIES:
-                    cipher_id = context['data']['ciphers'][0]
-                    self.cipher_id = Ciphers.get_cipher(cipher_id)
+                context_type = context["context_type"].get_value()
+
+                if context_type == NegotiateContextType.SMB2_ENCRYPTION_CAPABILITIES:
+                    self.cipher_id = context['data']['ciphers'][0]
                     self.supports_encryption = self.cipher_id != 0
-                else:
-                    hash_id = context['data']['hash_algorithms'][0]
-                    self.preauth_integrity_hash_id = \
-                        HashAlgorithms.get_algorithm(hash_id)
+
+                elif context_type == NegotiateContextType.SMB2_PREAUTH_INTEGRITY_CAPABILITIES:
+                    self.preauth_integrity_hash_id = context['data']['hash_algorithms'][0]
+
+                elif context_type == NegotiateContextType.SMB2_SIGNING_CAPABILITIES:
+                    self.signing_algorithm_id = context['data']['signing_algorithms'][0]
 
     def disconnect(self, close=True):
         """
@@ -989,7 +1065,8 @@ class Connection(object):
         if session is None:
             raise SMBException("Failed to find session %s for message verification" % session_id)
 
-        expected = self._generate_signature(header.pack(), session.signing_key)
+        expected = self._generate_signature(header.pack(), session.signing_key, message_id,
+                                            flags.has_flag(Smb2Flags.SMB2_FLAGS_SERVER_TO_REDIR), command)
         actual = header['signature'].get_value()
         if actual != expected:
             raise SMBException("Server message signature could not be verified: %s != %s"
@@ -1068,7 +1145,7 @@ class Connection(object):
             if force_signature or (session and session.signing_required and session.signing_key):
                 header['flags'].set_flag(Smb2Flags.SMB2_FLAGS_SIGNED)
                 b_header = header.pack() + padding
-                signature = self._generate_signature(b_header, session.signing_key)
+                signature = self._generate_signature(b_header, session.signing_key, current_id, False, message.COMMAND)
 
                 # To save on unpacking and re-packing, manually adjust the signature and update the request object for
                 # back-referencing.
@@ -1211,13 +1288,37 @@ class Connection(object):
             for request in self.outstanding_requests.values():
                 request.response_event.set()
 
-    def _generate_signature(self, b_header, signing_key):
+    def _generate_signature(self, b_header, signing_key, message_id, response, command):
         b_header = b_header[:48] + (b"\x00" * 16) + b_header[64:]
 
-        if self.dialect >= Dialects.SMB_3_0_0:
+        if self.dialect >= Dialects.SMB_3_1_1 and self.signing_algorithm_id is not None:
+            sign_id = self.signing_algorithm_id
+
+        elif self.dialect >= Dialects.SMB_3_0_0:
+            sign_id = SigningAlgorithms.AES_CMAC
+
+        else:
+            sign_id = SigningAlgorithms.HMAC_SHA256
+
+        if sign_id == SigningAlgorithms.AES_GMAC:
+            message_info = 0
+            if response:
+                message_info |= 1
+
+            if command == Commands.SMB2_CANCEL:
+                message_info |= 2
+
+            nonce = b"".join([
+                message_id.to_bytes(8, byteorder="little"),
+                message_info.to_bytes(4, byteorder="little"),
+            ])
+            signature = aead.AESGCM(signing_key).encrypt(nonce, b"", b_header)
+
+        elif sign_id == SigningAlgorithms.AES_CMAC:
             c = cmac.CMAC(algorithms.AES(signing_key), backend=default_backend())
             c.update(b_header)
             signature = c.finalize()
+
         else:
             hmac_algo = hmac.new(signing_key, msg=b_header, digestmod=hashlib.sha256)
             signature = hmac_algo.digest()[:16]
@@ -1231,13 +1332,16 @@ class Connection(object):
 
         encryption_key = session.encryption_key
         if self.dialect >= Dialects.SMB_3_1_1:
-            cipher = self.cipher_id
+            cipher_id = self.cipher_id
         else:
-            cipher = Ciphers.get_cipher(Ciphers.AES_128_CCM)
-        if cipher == aead.AESGCM:
+            cipher_id = Ciphers.AES_128_CCM
+
+        if cipher_id in [Ciphers.AES_128_GCM, Ciphers.AES_256_GCM]:
+            cipher = aead.AESGCM
             nonce = os.urandom(12)
             header['nonce'] = nonce + (b"\x00" * 4)
         else:
+            cipher = aead.AESCCM
             nonce = os.urandom(11)
             header['nonce'] = nonce + (b"\x00" * 5)
 
@@ -1263,13 +1367,18 @@ class Connection(object):
             raise SMBException(error_msg)
 
         if self.dialect >= Dialects.SMB_3_1_1:
-            cipher = self.cipher_id
+            cipher_id = self.cipher_id
         else:
-            cipher = Ciphers.get_cipher(Ciphers.AES_128_CCM)
+            cipher_id = Ciphers.AES_128_CCM
 
-        nonce_length = 12 if cipher == aead.AESGCM else 11
-        nonce = message['nonce'].get_value()[:nonce_length]
+        if cipher_id in [Ciphers.AES_128_GCM, Ciphers.AES_256_GCM]:
+            cipher = aead.AESGCM
+            nonce_length = 12
+        else:
+            cipher = aead.AESCCM
+            nonce_length = 11
 
+        nonce = message['nonce'].get_value()[:nonce_length]
         signature = message['signature'].get_value()
         enc_message = message['data'].get_value() + signature
 
@@ -1277,29 +1386,42 @@ class Connection(object):
         dec_message = c.decrypt(nonce, enc_message, message.pack()[20:52])
         return dec_message
 
-    def _send_smb2_negotiate(self, dialect, timeout):
+    def _send_smb2_negotiate(self, dialect, timeout, encryption_algorithms, signing_algorithms):
         self.salt = os.urandom(32)
 
         if dialect is None:
             neg_req = SMB3NegotiateRequest()
-            self.negotiated_dialects = [
+            negotiated_dialects = [
                 Dialects.SMB_2_0_2,
                 Dialects.SMB_2_1_0,
                 Dialects.SMB_3_0_0,
                 Dialects.SMB_3_0_2,
                 Dialects.SMB_3_1_1
             ]
-            highest_dialect = Dialects.SMB_3_1_1
+
+            if SigningAlgorithms.HMAC_SHA256 not in signing_algorithms:
+                if Dialects.SMB_2_0_2 in negotiated_dialects:
+                    negotiated_dialects.remove(Dialects.SMB_2_0_2)
+                if Dialects.SMB_2_1_0 in negotiated_dialects:
+                    negotiated_dialects.remove(Dialects.SMB_2_1_0)
+
+            if (
+                SigningAlgorithms.AES_CMAC not in signing_algorithms or
+                Ciphers.AES_128_CCM not in encryption_algorithms
+            ):
+                if Dialects.SMB_3_0_0 in negotiated_dialects:
+                    negotiated_dialects.remove(Dialects.SMB_3_0_0)
+                if Dialects.SMB_3_0_2 in negotiated_dialects:
+                    negotiated_dialects.remove(Dialects.SMB_3_0_2)
         else:
             if dialect >= Dialects.SMB_3_1_1:
                 neg_req = SMB3NegotiateRequest()
             else:
                 neg_req = SMB2NegotiateRequest()
-            self.negotiated_dialects = [
-                dialect
-            ]
-            highest_dialect = dialect
-        neg_req['dialects'] = self.negotiated_dialects
+            negotiated_dialects = [dialect]
+
+        highest_dialect = sorted(negotiated_dialects)[-1]
+        self.negotiated_dialects = neg_req['dialects'] = negotiated_dialects
         log.info("Negotiating with SMB2 protocol with highest client dialect "
                  "of: %s" % [dialect for dialect, v in vars(Dialects).items()
                              if v == highest_dialect][0])
@@ -1340,17 +1462,32 @@ class Connection(object):
             enc_cap['context_type'] = \
                 NegotiateContextType.SMB2_ENCRYPTION_CAPABILITIES
             enc_cap['data'] = SMB2EncryptionCapabilities()
-            supported_ciphers = Ciphers.get_supported_ciphers()
+            supported_ciphers = encryption_algorithms
             enc_cap['data']['ciphers'] = supported_ciphers
+            log.debug("Adding encryption capabilities of AES128|256 GCM and "
+                      "AES128|256 CCM to negotiate request")
+
+            netname_id = SMB2NegotiateContextRequest()
+            netname_id['context_type'] = NegotiateContextType.SMB2_NETNAME_NEGOTIATE_CONTEXT_ID
+            netname_id['data'] = SMB2NetnameNegotiateContextId()
+            netname_id['data']['net_name'] = self.server_name
+            log.debug(f"Adding netname context id of {self.server_name} to negotiate request")
+
+            signing_cap = SMB2NegotiateContextRequest()
+            signing_cap['context_type'] = NegotiateContextType.SMB2_SIGNING_CAPABILITIES
+            signing_cap['data'] = SMB2SigningCapabilities()
+            signing_cap['data']['signing_algorithms'] = signing_algorithms
+            log.debug("Adding signing algorithms AES_GMAC, AES_CMAC, and HMAC_SHA256 to negotiate request")
+
             # remove extra padding for last list entry
-            enc_cap['padding'].size = 0
-            enc_cap['padding'] = b""
-            log.debug("Adding encryption capabilities of AES128 GCM and "
-                      "AES128 CCM to negotiate request")
+            signing_cap['padding'].size = 0
+            signing_cap['padding'] = b""
 
             neg_req['negotiate_context_list'] = [
                 int_cap,
-                enc_cap
+                enc_cap,
+                netname_id,
+                signing_cap,
             ]
 
         log.info("Sending SMB2 Negotiate message")
diff --git a/smbprotocol/session.py b/smbprotocol/session.py
index cba9cc6..d6d33bc 100644
--- a/smbprotocol/session.py
+++ b/smbprotocol/session.py
@@ -2,6 +2,7 @@
 # Copyright: (c) 2019, Jordan Borean (@jborean93) <jborean93@gmail.com>
 # MIT License (see LICENSE or https://opensource.org/licenses/MIT)
 
+import hashlib
 import logging
 import random
 import spnego
@@ -30,6 +31,7 @@ from smbprotocol import (
 
 from smbprotocol.connection import (
     Capabilities,
+    Ciphers,
     SecurityMode,
 )
 
@@ -251,6 +253,7 @@ class Session(object):
         self.decryption_key = None
         self.signing_key = None
         self.application_key = None
+        self.full_session_key = None
 
         # SMB 3.1.1+
         # Preauth integrity value computed for the exhange of SMB2
@@ -321,19 +324,31 @@ class Session(object):
         self.connection.session_table[self.session_id] = self.connection.preauth_session_table.pop(self.session_id)
 
         # session_key is the first 16 bytes, padded 0 if less than 16
-        self.session_key = context.session_key[:16].ljust(16, b"\x00")
+        self.full_session_key = context.session_key
+        self.session_key = self.full_session_key[:16].ljust(16, b"\x00")
 
         if self.connection.dialect >= Dialects.SMB_3_1_1:
             preauth_hash = b"\x00" * 64
-            hash_al = self.connection.preauth_integrity_hash_id
             for hash_list in [self.connection.preauth_integrity_hash_value, self.preauth_integrity_hash_value]:
                 for message in hash_list:
-                    preauth_hash = hash_al(preauth_hash + message).digest()
+                    # Technically the algo is based on preauth_integrity_hash_id but we only support the 1
+                    preauth_hash = hashlib.sha512(preauth_hash + message).digest()
 
             self.signing_key = self._smb3kdf(self.session_key, b"SMBSigningKey\x00", preauth_hash)
             self.application_key = self._smb3kdf(self.session_key, b"SMBAppKey\x00", preauth_hash)
-            self.encryption_key = self._smb3kdf(self.session_key, b"SMBC2SCipherKey\x00", preauth_hash)
-            self.decryption_key = self._smb3kdf(self.session_key, b"SMBS2CCipherKey\x00", preauth_hash)
+
+            if self.connection.cipher_id in [
+                Ciphers.AES_256_CCM,
+                Ciphers.AES_256_GCM,
+            ]:
+                key_length = 32
+                key = self.full_session_key
+            else:
+                key_length = 16
+                key = self.session_key
+
+            self.encryption_key = self._smb3kdf(key, b"SMBC2SCipherKey\x00", preauth_hash, length=key_length)
+            self.decryption_key = self._smb3kdf(key, b"SMBS2CCipherKey\x00", preauth_hash, length=key_length)
 
         elif self.connection.dialect >= Dialects.SMB_3_0_0:
             self.signing_key = self._smb3kdf(self.session_key, b"SMB2AESCMAC\x00", b"SmbSign\x00")
@@ -405,7 +420,7 @@ class Session(object):
         self._connected = False
         del self.connection.session_table[self.session_id]
 
-    def _smb3kdf(self, ki, label, context):
+    def _smb3kdf(self, ki, label, context, length=16):
         """
         See SMB 3.x key derivation function
         https://blogs.msdn.microsoft.com/openspecification/2017/05/26/smb-2-and-smb-3-security-in-windows-10-the-anatomy-of-signing-and-cryptographic-keys/
@@ -413,13 +428,14 @@ class Session(object):
         :param ki: The session key is the KDK used as an input to the KDF
         :param label: The purpose of this derived key as bytes string
         :param context: The context information of this derived key as bytes
+        :param length: The length of the key to generate
         string
         :return: Key derived by the KDF as specified by [SP800-108] 5.1
         """
         kdf = KBKDFHMAC(
             algorithm=hashes.SHA256(),
             mode=Mode.CounterMode,
-            length=16,
+            length=length,
             rlen=4,
             llen=4,
             location=CounterLocation.BeforeFixed,
diff --git a/tests/test_connection.py b/tests/test_connection.py
index 9dea483..76f3fab 100644
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -2,15 +2,10 @@
 # Copyright: (c) 2019, Jordan Borean (@jborean93) <jborean93@gmail.com>
 # MIT License (see LICENSE or https://opensource.org/licenses/MIT)
 
-import hashlib
 import os
 import pytest
 import uuid
 
-from cryptography.hazmat.primitives.ciphers import (
-    aead,
-)
-
 from datetime import (
     datetime,
 )
@@ -25,6 +20,8 @@ from smbprotocol.connection import (
     HashAlgorithms,
     NegotiateContextType,
     Request,
+    SMB2NetnameNegotiateContextId,
+    SMB2SigningCapabilities,
     SecurityMode,
     SMB2CancelRequest,
     SMB2EncryptionCapabilities,
@@ -35,6 +32,7 @@ from smbprotocol.connection import (
     SMB2PreauthIntegrityCapabilities,
     SMB2TransformHeader,
     SMB3NegotiateRequest,
+    SigningAlgorithms,
 )
 
 from smbprotocol.header import (
@@ -56,30 +54,6 @@ from smbprotocol.session import (
 )
 
 
-def test_valid_hash_algorithm():
-    expected = hashlib.sha512
-    actual = HashAlgorithms.get_algorithm(0x1)
-    assert actual == expected
-
-
-def test_invalid_hash_algorithm():
-    with pytest.raises(KeyError) as exc:
-        HashAlgorithms.get_algorithm(0x2)
-        assert False  # shouldn't be reached
-
-
-def test_valid_cipher():
-    expected = aead.AESCCM
-    actual = Ciphers.get_cipher(0x1)
-    assert actual == expected
-
-
-def test_invalid_cipher():
-    with pytest.raises(KeyError) as exc:
-        Ciphers.get_cipher(0x3)
-        assert False  # shouldn't be reached
-
-
 class TestSMB2NegotiateRequest(object):
 
     def test_create_message(self):
@@ -163,8 +137,15 @@ class TestSMB3NegotiateRequest(object):
         enc_cap = SMB2EncryptionCapabilities()
         enc_cap['ciphers'] = [Ciphers.AES_128_GCM]
         con_req['data'] = enc_cap
+
+        netname = SMB2NegotiateContextRequest()
+        netname['context_type'] = NegotiateContextType.SMB2_NETNAME_NEGOTIATE_CONTEXT_ID
+        netname['data'] = SMB2NetnameNegotiateContextId()
+        netname['data']['net_name'] = 'café'
+
         message['negotiate_context_list'] = [
-            con_req
+            con_req,
+            netname,
         ]
         expected = b"\x24\x00" \
                    b"\x05\x00" \
@@ -174,7 +155,7 @@ class TestSMB3NegotiateRequest(object):
                    b"\x33\x33\x33\x33\x33\x33\x33\x33" \
                    b"\x33\x33\x33\x33\x33\x33\x33\x33" \
                    b"\x70\x00\x00\x00" \
-                   b"\x01\x00" \
+                   b"\x02\x00" \
                    b"\x00\x00" \
                    b"\x02\x02" \
                    b"\x10\x02" \
@@ -184,9 +165,12 @@ class TestSMB3NegotiateRequest(object):
                    b"\x00\x00" \
                    b"\x02\x00\x04\x00\x00\x00\x00\x00" \
                    b"\x01\x00\x02\x00" \
-                   b"\x00\x00\x00\x00"
+                   b"\x00\x00\x00\x00" \
+                   b"\x05\x00\x08\x00\x00\x00\x00\x00" \
+                   b"\x63\x00\x61\x00\x66\x00\xe9\x00"
+
         actual = message.pack()
-        assert len(message) == 64
+        assert len(message) == 80
         assert actual == expected
 
     def test_create_message_one_dialect(self):
@@ -236,7 +220,7 @@ class TestSMB3NegotiateRequest(object):
                b"\x33\x33\x33\x33\x33\x33\x33\x33" \
                b"\x33\x33\x33\x33\x33\x33\x33\x33" \
                b"\x70\x00\x00\x00" \
-               b"\x01\x00" \
+               b"\x02\x00" \
                b"\x00\x00" \
                b"\x02\x02" \
                b"\x10\x02" \
@@ -246,9 +230,12 @@ class TestSMB3NegotiateRequest(object):
                b"\x00\x00" \
                b"\x02\x00\x04\x00\x00\x00\x00\x00" \
                b"\x01\x00\x02\x00" \
-               b"\x00\x00\x00\x00"
+               b"\x00\x00\x00\x00" \
+               b"\x05\x00\x08\x00\x00\x00\x00\x00" \
+               b"\x63\x00\x61\x00\x66\x00\xe9\x00"
+
         actual.unpack(data)
-        assert len(actual) == 60
+        assert len(actual) == 76
         assert actual['structure_size'].get_value() == 36
         assert actual['dialect_count'].get_value() == 5
         assert actual['security_mode'].get_value() == \
@@ -258,7 +245,7 @@ class TestSMB3NegotiateRequest(object):
         assert actual['client_guid'].get_value() == \
             uuid.UUID(bytes=b"\x33" * 16)
         assert actual['negotiate_context_offset'].get_value() == 112
-        assert actual['negotiate_context_count'].get_value() == 1
+        assert actual['negotiate_context_count'].get_value() == 2
         assert actual['reserved2'].get_value() == 0
         assert actual['dialects'].get_value() == [
             Dialects.SMB_2_0_2,
@@ -269,7 +256,7 @@ class TestSMB3NegotiateRequest(object):
         ]
         assert actual['padding'].get_value() == b"\x00\x00"
 
-        assert len(actual['negotiate_context_list'].get_value()) == 1
+        assert len(actual['negotiate_context_list'].get_value()) == 2
         neg_con = actual['negotiate_context_list'][0]
         assert isinstance(neg_con, SMB2NegotiateContextRequest)
         assert len(neg_con) == 12
@@ -282,6 +269,15 @@ class TestSMB3NegotiateRequest(object):
         assert neg_con['data']['cipher_count'].get_value() == 1
         assert neg_con['data']['ciphers'].get_value() == [Ciphers.AES_128_GCM]
 
+        net_name = actual['negotiate_context_list'][1]
+        assert isinstance(net_name, SMB2NegotiateContextRequest)
+        assert len(net_name) == 16
+        assert net_name['context_type'].get_value() == NegotiateContextType.SMB2_NETNAME_NEGOTIATE_CONTEXT_ID
+        assert net_name['data_length'].get_value() == 8
+        assert net_name['reserved'].get_value() == 0
+        assert isinstance(net_name['data'].get_value(), SMB2NetnameNegotiateContextId)
+        assert net_name['data']['net_name'].get_value() == 'café'
+
 
 class TestSMB2NegotiateContextRequest(object):
 
@@ -323,14 +319,14 @@ class TestSMB2NegotiateContextRequest(object):
 
     def test_parse_message_invalid_context_type(self):
         actual = SMB2NegotiateContextRequest()
-        data = b"\x03\x00" \
+        data = b"\xFF\xFF" \
                b"\x04\x00" \
                b"\x00\x00\x00\x00" \
                b"\x01\x00" \
                b"\x02\x00"
         with pytest.raises(Exception) as exc:
             actual.unpack(data)
-        assert str(exc.value) == "Enum value 3 does not exist in enum type " \
+        assert str(exc.value) == "Enum value 65535 does not exist in enum type " \
                                  "<class 'smbprotocol.connection." \
                                  "NegotiateContextType'>"
 
@@ -398,6 +394,58 @@ class TestSMB2EncryptionCapabilities(object):
         ]
 
 
+class TestSMB2NetnameNegotiateContextId:
+
+    def test_create_message(self):
+        message = SMB2NetnameNegotiateContextId()
+        message["net_name"] = "hostname"
+        expected = b"\x68\x00\x6F\x00\x73\x00\x74\x00" \
+                   b"\x6E\x00\x61\x00\x6D\x00\x65\x00"
+        actual = message.pack()
+        assert len(message) == 16
+        assert actual == expected
+
+    def test_parse_message(self):
+        actual = SMB2NetnameNegotiateContextId()
+        data = b"\x68\x00\x6F\x00\x73\x00\x74\x00" \
+               b"\x6E\x00\x61\x00\x6D\x00\x65\x00"
+        actual.unpack(data)
+        assert len(actual) == 16
+        assert actual["net_name"].get_value() == "hostname"
+
+
+class TestSMB2SigningCapabilities:
+
+    def test_create_message(self):
+        message = SMB2SigningCapabilities()
+        message["signing_algorithms"] = [
+            SigningAlgorithms.AES_GMAC,
+            SigningAlgorithms.AES_CMAC,
+            SigningAlgorithms.HMAC_SHA256,
+        ]
+        expected = b"\x03\x00" \
+                   b"\x02\x00" \
+                   b"\x01\x00" \
+                   b"\x00\x00"
+        actual = message.pack()
+        assert len(message) == 8
+        assert actual == expected
+
+    def test_parse_message(self):
+        actual = SMB2SigningCapabilities()
+        data = b"\x03\x00" \
+               b"\x02\x00" \
+               b"\x01\x00" \
+               b"\x00\x00"
+        actual.unpack(data)
+        assert len(actual) == 8
+        assert actual["signing_algorithms"].get_value() == [
+            SigningAlgorithms.AES_GMAC,
+            SigningAlgorithms.AES_CMAC,
+            SigningAlgorithms.HMAC_SHA256,
+        ]
+
+
 class TestSMB2NegotiateResponse(object):
 
     def test_create_message(self):
@@ -1079,51 +1127,61 @@ class TestConnection(object):
         finally:
             connection.disconnect(True)
 
-    def test_encrypt_ccm(self, monkeypatch):
+    @pytest.mark.parametrize('key, cipher, signature, data', [
+        (b"\xff" * 16, Ciphers.AES_128_CCM, b"\xc8\x73\x0c\x9b\xa7\xe5\x9f\x1c\xfd\x37\x51\xa1\x95\xf2\xb3\xac",
+         b"\x21\x91\xe3\x0e"),
+        (b"\xff" * 32, Ciphers.AES_256_CCM, b"\x3E\xFB\x47\x97\x51\x8A\xAB\x05\xC5\x48\xA7\xFC\x20\x74\xF5\x93",
+         b"\x2F\x58\x41\xD7"),
+    ], ids=['AES128_CCM', 'AES256_CCM'])
+    def test_encrypt_ccm(self, key, cipher, signature, data, monkeypatch):
         def mockurandom(length):
             return b"\xff" * length
         monkeypatch.setattr(os, 'urandom', mockurandom)
 
         connection = Connection(uuid.uuid4(), "server", 445)
         connection.dialect = Dialects.SMB_3_1_1
-        connection.cipher_id = Ciphers.get_cipher(Ciphers.AES_128_CCM)
+        connection.cipher_id = cipher
         session = Session(connection, "user", "pass")
         session.session_id = 1
-        session.encryption_key = b"\xff" * 16
+        session.encryption_key = key
 
         expected = SMB2TransformHeader()
-        expected['signature'] = b"\xc8\x73\x0c\x9b\xa7\xe5\x9f\x1c" \
-            b"\xfd\x37\x51\xa1\x95\xf2\xb3\xac"
+        expected['signature'] = signature
         expected['nonce'] = b"\xff" * 11 + b"\x00" * 5
         expected['original_message_size'] = 4
         expected['flags'] = 1
         expected['session_id'] = 1
-        expected['data'] = b"\x21\x91\xe3\x0e"
+        expected['data'] = data
 
         actual = connection._encrypt(b"\x01\x02\x03\x04", session)
         assert isinstance(actual, SMB2TransformHeader)
         assert actual.pack() == expected.pack()
 
-    def test_encrypt_gcm(self, monkeypatch):
+    @pytest.mark.parametrize('key, cipher, signature, data', [
+        (b"\xff" * 16, Ciphers.AES_128_GCM, b"\x39\xd8\x32\x34\xd7\x53\xd0\x8e\xc0\xfc\xbe\x33\x01\x5f\x19\xbd",
+         b"\xda\x26\x57\x33"),
+        (b"\xff" * 32, Ciphers.AES_256_GCM, b"\x45\xE5\xB7\x23\x05\x2E\xCA\xD0\x1E\xEF\xAD\x6F\x04\x87\xE3\x2D",
+         b"\xBC\x39\xBD\x81"),
+    ], ids=['AES128_CCM', 'AES256_CCM'])
+    def test_encrypt_gcm(self, key, cipher, signature, data, monkeypatch):
         def mockurandom(length):
             return b"\xff" * length
         monkeypatch.setattr(os, 'urandom', mockurandom)
 
         connection = Connection(uuid.uuid4(), "server", 445)
         connection.dialect = Dialects.SMB_3_1_1
-        connection.cipher_id = Ciphers.get_cipher(Ciphers.AES_128_GCM)
+        connection.cipher_id = cipher
         session = Session(connection, "user", "pass")
         session.session_id = 1
-        session.encryption_key = b"\xff" * 16
+        session.encryption_key = key
 
         expected = SMB2TransformHeader()
-        expected['signature'] = b"\x39\xd8\x32\x34\xd7\x53\xd0\x8e" \
-            b"\xc0\xfc\xbe\x33\x01\x5f\x19\xbd"
+        expected['signature'] = signature
         expected['nonce'] = b"\xff" * 12 + b"\x00" * 4
         expected['original_message_size'] = 4
         expected['flags'] = 1
         expected['session_id'] = 1
-        expected['data'] = b"\xda\x26\x57\x33"
+        expected['data'] = data
 
         actual = connection._encrypt(b"\x01\x02\x03\x04", session)
         assert isinstance(actual, SMB2TransformHeader)
diff --git a/tests/test_smbclient_os.py b/tests/test_smbclient_os.py
index 18f6af8..9b4eb57 100644
--- a/tests/test_smbclient_os.py
+++ b/tests/test_smbclient_os.py
@@ -826,8 +826,6 @@ def test_open_file_unbuffered_text_file(smb_share):
         smbclient.open_file("%s\\file.txt" % smb_share, mode='w', buffering=0)
 
 
-@pytest.mark.skipif(os.name != "nt" and not os.environ.get('SMB_FORCE', False),
-                    reason="Bug on latest Samba https://bugzilla.samba.org/show_bug.cgi?id=14877")
 def test_open_file_with_ads(smb_share):
     filename = "%s\\file.txt" % smb_share
     with smbclient.open_file(filename, mode='w') as fd:
diff --git a/tests/test_tree.py b/tests/test_tree.py
index 6ec25aa..948ebce 100644
--- a/tests/test_tree.py
+++ b/tests/test_tree.py
@@ -10,7 +10,9 @@ from smbprotocol import (
 )
 
 from smbprotocol.connection import (
+    Ciphers,
     Connection,
+    SigningAlgorithms,
 )
 
 from smbprotocol.exceptions import (
@@ -266,3 +268,50 @@ class TestTreeConnect(object):
         finally:
             connection.disconnect(True)
             tree.disconnect()  # test that disconnect can be run mutliple times
+
+    @pytest.mark.parametrize('cipher', [
+        Ciphers.AES_128_CCM,
+        Ciphers.AES_128_GCM,
+        Ciphers.AES_256_CCM,
+        Ciphers.AES_256_GCM,
+    ], ids=['AES_128_CCM', 'AES_128_GCM', 'AES_256_CCM', 'AES_256_GCM'])
+    def test_encryption(self, cipher, smb_real):
+        connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3])
+        connection.connect(preferred_encryption_algos=[cipher])
+
+        try:
+            if connection.cipher_id == 0:
+                pytest.skip("Server did not support encryption requested")
+
+            assert connection.cipher_id == cipher
+
+            session = Session(connection, smb_real[0], smb_real[1])
+            tree = TreeConnect(session, smb_real[4])
+            session.connect()
+            tree.connect()
+
+        finally:
+            connection.disconnect(True)
+
+    @pytest.mark.parametrize('algo', [
+        SigningAlgorithms.AES_GMAC,
+        SigningAlgorithms.AES_CMAC,
+        SigningAlgorithms.HMAC_SHA256,
+    ], ids=['AES_GMAC', 'AES_CMAC', 'HMAC_SHA256'])
+    def test_signing(self, algo, smb_real):
+        connection = Connection(uuid.uuid4(), smb_real[2], smb_real[3])
+        connection.connect(preferred_signing_algos=[algo])
+
+        try:
+            if connection.signing_algorithm_id is None:
+                pytest.skip("Server did not support signing algo requested")
+
+            assert connection.signing_algorithm_id == algo
+
+            session = Session(connection, smb_real[0], smb_real[1], require_encryption=False)
+            tree = TreeConnect(session, smb_real[4])
+            session.connect()
+            tree.connect()
+
+        finally:
+            connection.disconnect(True)
-- 
GitLab