From cd3e8267d6836418d6fa3aa50dbc8d473e5f2afd Mon Sep 17 00:00:00 2001
From: Andreas Ziegler <andreas.ziegler@fau.de>
Date: Mon, 7 Feb 2022 12:40:01 +0100
Subject: [PATCH] library, store, calls: partly handle dlsym

If we detect a call to a function containing 'dlsym' in its name,
try to extract the target string from the underlying file and
store this function name as referenced from the currently
disassembled function. This way, the transitive call selection
can later pick this edge up and mark the target function as used.

Additionally, extract the target strings for 'dlopen'-named
functions in order to allow the identification of dynamically
loaded libraries.

This currently only works on x86_64.
---
 librarytrader/interface_calls.py | 88 ++++++++++++++++++++++++++++----
 librarytrader/library.py         |  5 ++
 librarytrader/librarystore.py    | 54 ++++++++++++++------
 3 files changed, 122 insertions(+), 25 deletions(-)

diff --git a/librarytrader/interface_calls.py b/librarytrader/interface_calls.py
index 482bb2d..08c2ee6 100644
--- a/librarytrader/interface_calls.py
+++ b/librarytrader/interface_calls.py
@@ -30,6 +30,7 @@ import time
 
 import capstone
 from elftools.common.exceptions import ELFError
+from elftools.common.utils import parse_cstring_from_stream
 
 # In order to be able to use librarytrader from git without having installed it,
 # add top level directory to PYTHONPATH
@@ -108,7 +109,8 @@ def find_calls_from_objdump(library, disas):
             elif target in library.local_functions:
                 calls_to_locals.add(target)
 
-    return (calls_to_exports, calls_to_imports, calls_to_locals, {}, {}, {}, {})
+    return (calls_to_exports, calls_to_imports, calls_to_locals, {}, {}, {},
+            {}, {}, {})
 
 def disassemble_capstone(library, start, length, cs_obj):
     disassembly = []
@@ -130,6 +132,30 @@ def disassemble_capstone(library, start, length, cs_obj):
 
     return (disas, find_calls_from_capstone)
 
+def _locate_parameter(library, disas, start_idx, target_register, mem_tag):
+    """Return the string constant loaded into target_register before a call.
+
+    Scans disas backwards from start_idx for the last write to the register.
+    Only <lea xxx(%rip), reg> is supported (x86_64); otherwise returns None.
+    """
+    for earlier_insn in reversed(disas[:start_idx]):
+        _, written = earlier_insn.regs_access()
+        if target_register not in written:
+            continue
+        # This write defines the parameter; earlier writes are stale, so stop.
+        if earlier_insn.id != capstone.x86_const.X86_INS_LEA:
+            return None
+        _, val = list(earlier_insn.operands)
+        if val.type != mem_tag or val.value.mem.base != capstone.x86.X86_REG_RIP:
+            return None
+        stroff = earlier_insn.address + val.value.mem.disp + earlier_insn.size
+        # Does this need library._elffile.address_offsets()?
+        # In mariadb's glibc, .rodata is 1:1 mapped...
+        strval = parse_cstring_from_stream(library.fd, stroff)
+        # None means no terminating NUL byte was found in the file.
+        return strval.decode('utf-8') if strval is not None else None
+    return None
+
 def find_calls_from_capstone(library, disas):
     calls_to_exports = set()
     calls_to_imports = set()
@@ -138,6 +164,8 @@ def find_calls_from_capstone(library, disas):
     imported_object_refs = set()
     exported_object_refs = set()
     local_object_refs = set()
+    dlsym_refs = set()
+    dlopen_refs = set()
     if library.is_aarch64():
         call_group = capstone.arm64_const.ARM64_GRP_CALL
         jump_group = capstone.arm64_const.ARM64_GRP_JUMP
@@ -155,7 +183,7 @@ def find_calls_from_capstone(library, disas):
     thunk_reg = None
     thunk_val = None
 
-    for instr in disas:
+    for idx, instr in enumerate(disas):
         if instr.group(call_group) or instr.group(jump_group):
             operand = instr.operands[-1]
             if operand.type == imm_tag:
@@ -167,10 +195,42 @@ def find_calls_from_capstone(library, disas):
                 indirect_calls.add((instr.address, '{} {}'.format(instr.mnemonic, instr.op_str)))
                 continue
             if target in library.exported_addrs:
+                for name in library.exported_addrs[target]:
+                    if 'dlsym' in name:
+                        logging.debug('%s: call to %s at offset %x',
+                                      library.fullname, name, instr.address)
+                        param = _locate_parameter(library, disas, idx,
+                                                  capstone.x86_const.X86_REG_RSI,
+                                                  mem_tag)
+                        if param:
+                            dlsym_refs.add(param)
+                        break
+                    elif 'dlopen' in name:
+                        param = _locate_parameter(library, disas, idx,
+                                                capstone.x86_const.X86_REG_RDI,
+                                                mem_tag)
+                        if param:
+                            dlopen_refs.add(param)
+                        break
                 calls_to_exports.add(target)
             elif target in library.exports_plt:
                 calls_to_exports.add(library.exports_plt[target])
             elif target in library.imports_plt:
+                if 'dlsym' in library.imports_plt[target]:
+                    logging.debug('%s: call to imported %s at offset %x',
+                                  library.fullname, library.imports_plt[target],
+                                  instr.address)
+                    param = _locate_parameter(library, disas, idx,
+                                              capstone.x86_const.X86_REG_RSI,
+                                              mem_tag)
+                    if param:
+                        dlsym_refs.add(param)
+                elif 'dlopen' in library.imports_plt[target]:
+                    param = _locate_parameter(library, disas, idx,
+                                              capstone.x86_const.X86_REG_RDI,
+                                              mem_tag)
+                    if param:
+                        dlopen_refs.add(param)
                 calls_to_imports.add(library.imports_plt[target])
             elif target in library.local_functions:
                 # Note: this might only work for gcc-compiled libraries, as
@@ -283,7 +343,8 @@ def find_calls_from_capstone(library, disas):
                     local_object_refs.add(addr)
 
     return (calls_to_exports, calls_to_imports, calls_to_locals, indirect_calls,
-            imported_object_refs, exported_object_refs, local_object_refs)
+            imported_object_refs, exported_object_refs, local_object_refs,
+            dlsym_refs, dlopen_refs)
 
 def resolve_calls_in_library(library, start, size, disas_function=disassemble_capstone):
     logging.debug('Processing %s:%x', library.fullname, start)
@@ -294,13 +355,15 @@ def resolve_calls_in_library(library, start, size, disas_function=disassemble_ca
     indir = {}
     disas, resolution_function = disas_function(library, start, size, cs_obj)
     calls_to_exports, calls_to_imports, calls_to_locals, indirect_calls, \
-        uses_of_imports, uses_of_exports, uses_of_locals = resolution_function(library, disas)
+        uses_of_imports, uses_of_exports, uses_of_locals, dlsym_refs, \
+        dlopen_refs = resolution_function(library, disas)
 
     indir[start] = indirect_calls
 
     after = time.time()
     return (calls_to_exports, calls_to_imports, calls_to_locals, indir,
-            uses_of_imports, uses_of_exports, uses_of_locals, (after - before))
+            uses_of_imports, uses_of_exports, uses_of_locals, dlsym_refs,
+            dlopen_refs, (after - before))
 
 def map_wrapper(input_tuple):
     path, start, size = input_tuple
@@ -312,14 +375,16 @@ def map_wrapper(input_tuple):
             lib.fd = open(lib.fullname, 'rb')
     except Exception as err:
         logging.error('%s: %s', lib.fullname, err)
-        return (None, -1, None, None, None, None, None, None, None, 0)
+        return (None, -1, None, None, None, None, None, None, None, None, None, 0)
 
     internal_calls, external_calls, local_calls, indirect_calls, \
-        imported_uses, exported_uses, local_uses, duration = resolve_calls_in_library(lib, start, size)
+        imported_uses, exported_uses, local_uses, dlsym_refs, \
+        dlopen_refs, duration = resolve_calls_in_library(lib, start, size)
     lib.fd.close()
     del lib.fd
-    return (lib.fullname, start, internal_calls, external_calls, local_calls, indirect_calls,
-            imported_uses, exported_uses, local_uses, duration)
+    return (lib.fullname, start, internal_calls, external_calls, local_calls,
+            indirect_calls, imported_uses, exported_uses, local_uses, dlsym_refs,
+            dlopen_refs, duration)
 
 def resolve_calls(store, n_procs=int(multiprocessing.cpu_count() * 1.5)):
     # Pass by path (-> threads have to reconstruct)
@@ -343,7 +408,8 @@ def resolve_calls(store, n_procs=int(multiprocessing.cpu_count() * 1.5)):
     indir = {}
     calls = 0
     for fullname, start, internal_calls, external_calls, local_calls, indirect_calls, \
-            imported_uses, exported_uses, local_uses, duration in result:
+            imported_uses, exported_uses, local_uses, dlsym_refs, dlopen_refs, \
+            duration in result:
         if not fullname:
             continue
         store[fullname].internal_calls[start].update(internal_calls)
@@ -356,6 +422,8 @@ def resolve_calls(store, n_procs=int(multiprocessing.cpu_count() * 1.5)):
             indir[fullname] = set()
         indir[fullname].update(indirect_calls)
         calls += len(internal_calls) + len(external_calls) + len(local_calls)
+        store[fullname].dlsym_refs[start].update(dlsym_refs)
+        store[fullname].dlopen_refs[start].update(dlopen_refs)
         store[fullname].total_disas_time += duration
 
     pool.join()
diff --git a/librarytrader/library.py b/librarytrader/library.py
index 065e287..42d79d1 100644
--- a/librarytrader/library.py
+++ b/librarytrader/library.py
@@ -127,6 +127,11 @@ class Library:
         self.export_object_refs = collections.defaultdict(set)
         self.local_object_refs = collections.defaultdict(set)
         self.import_object_refs = collections.defaultdict(set)
+        # Mapping of caller address -> names of symbols imported via a *dlsym*
+        # function. These can be located later.
+        self.dlsym_refs = collections.defaultdict(set)
+        # Mapping of caller functions -> names of libraries of *dlopen* calls.
+        self.dlopen_refs = collections.defaultdict(set)
 
         # external users of objects: address -> list of referencing library paths
         self.object_users = collections.OrderedDict()
diff --git a/librarytrader/librarystore.py b/librarytrader/librarystore.py
index 293ce02..d063f4e 100644
--- a/librarytrader/librarystore.py
+++ b/librarytrader/librarystore.py
@@ -344,6 +344,18 @@ class LibraryStore(BaseStore):
                                   library.fullname)
                     local_cache.add((dependent_function, target_lib))
 
+        # Add references to symbols imported via dlsym
+        # TODO: try to find symbols from libraries in dlopen_refs first?
+        if function in library.dlsym_refs:
+            for outgoing_ref in library.dlsym_refs[function]:
+                target_lib, target_addr = self._hard_search_for_symbol(library,
+                                                                       outgoing_ref)
+                if target_lib:
+                    logging.debug('%s: transitive %x -> %s/%x through dlsym',
+                                  library.fullname, function,
+                                  target_lib.fullname, target_addr)
+                    local_cache.add((target_addr, target_lib))
+
         self._callee_cache[libname][function] = local_cache
         return self._callee_cache[libname][function]
 
@@ -483,6 +495,23 @@ class LibraryStore(BaseStore):
                 return True
         return False
 
+    def _hard_search_for_symbol(self, from_library, function):
+        """Search all libraries in the store for an export named `function`.
+
+        Names are also compared with their '@@<version>' suffix stripped.
+        Returns (library, address) of the first hit, else (None, None).
+        `from_library` is currently unused.
+        """
+        for other_lib in self.get_library_objects():
+            target = other_lib.exported_names
+            if function not in target:
+                # Retry against the names with everything from '@@' removed
+                target = {x.split('@@')[0]: val for x, val in target.items()}
+                if function not in target:
+                    continue
+            return other_lib, target[function]
+        return None, None
+
     def resolve_functions(self, library, do_add=False):
         if isinstance(library, str):
             name = library
@@ -502,22 +531,13 @@ class LibraryStore(BaseStore):
             found = self._find_imported_function(function, library, add=do_add)
 
             if not found:
-                # Hack: search function in all libraries in the store
-                for other_lib in self.get_library_objects():
-                    if function in other_lib.exported_names:
-                        target = other_lib.exported_names
-                    else:
-                        to_search = [x.split('@@')[0] for x in other_lib.exported_names]
-                        if function in to_search:
-                            target = {x.split('@@')[0] : val for x, val in other_lib.exported_names.items()}
-                        else:
-                            continue
-                    target_name = other_lib.fullname
-                    target_addr = target[function]
-                    library.imports[function] = target_name
-                    other_lib.add_export_user(target_addr, library.fullname)
+                target_lib, target_addr = self._hard_search_for_symbol(library,
+                                                                       function)
+                if target_lib:
+                    library.imports[function] = target_lib.fullname
+                    target_lib.add_export_user(target_addr, library.fullname)
                     logging.info('hard search for %s, found at %s:%x', function,
-                                 target_name, target_addr)
+                                 target_lib.fullname, target_addr)
                     found = True
 
             if not found:
@@ -810,6 +830,8 @@ class LibraryStore(BaseStore):
                 dump_dict_with_set_value(lib_dict, content, "object_users")
                 dump_ordered_dict_as_list(lib_dict, content, "reloc_to_local")
                 dump_ordered_dict_as_list(lib_dict, content, "reloc_to_exported")
+                dump_dict_with_set_value(lib_dict, content, "dlsym_refs")
+                dump_dict_with_set_value(lib_dict, content, "dlopen_refs")
                 lib_dict["init_functions"] = content.init_functions
                 lib_dict["fini_functions"] = content.fini_functions
 
@@ -888,6 +910,8 @@ class LibraryStore(BaseStore):
                     load_dict_with_set_values(content, library, "object_users", int)
                     load_ordered_dict_from_list(content, library, "reloc_to_local")
                     load_ordered_dict_from_list(content, library, "reloc_to_exported")
+                    load_dict_with_set_values(content, library, "dlsym_refs", int)
+                    load_dict_with_set_values(content, library, "dlopen_refs", int)
                     library.init_functions = content.get("init_functions", [])
                     library.fini_functions = content.get("fini_functions", [])
 
-- 
GitLab