From 5c68ee1becfc55219b63742bdc7741e585b4516c Mon Sep 17 00:00:00 2001
From: esaunders <esaunders@basistech.com>
Date: Fri, 27 Dec 2013 11:16:26 -0500
Subject: [PATCH] Added private method to TskImgDBPostgreSQL that looks for an
 existing module based on its name and returns the module id. This method is
 called by TskImgDBPostgreSQL::addModule() to initially check if the module
 already exists. If the module does not exist we attempt to insert. This
 insert attempt may fail if another node in a distributed environment creates
 the module entry. In that case we again fall back on the new getModuleId()
 method.

---
 .../framework/services/TskImgDBPostgreSQL.cpp | 63 ++++++++++++++-----
 .../framework/services/TskImgDBPostgreSQL.h   |  1 +
 2 files changed, 47 insertions(+), 17 deletions(-)

diff --git a/framework/tsk/framework/services/TskImgDBPostgreSQL.cpp b/framework/tsk/framework/services/TskImgDBPostgreSQL.cpp
index e90874d23..20616faa0 100755
--- a/framework/tsk/framework/services/TskImgDBPostgreSQL.cpp
+++ b/framework/tsk/framework/services/TskImgDBPostgreSQL.cpp
@@ -2860,6 +2860,37 @@ int TskImgDBPostgreSQL::getFileTypeRecords(std::string& stmt, std::list<TskFileT
     return rc;
 }
 
+/**
+ *
+ */
+int TskImgDBPostgreSQL::getModuleId(const std::string& name, int & moduleId) const
+{
+    stringstream stmt;
+
+    stmt << "SELECT module_id FROM modules WHERE name = " << m_dbConnection->quote(name);
+
+    try 
+    {
+        pqxx::read_transaction trans(*m_dbConnection);
+        result R = trans.exec(stmt);
+
+        if (R.size() == 1)
+        {
+            R[0][0].to(moduleId);
+        }
+    }
+    catch(exception& e)
+    {
+        std::stringstream errorMsg;
+        errorMsg << "TskDBPostgreSQL::getModuleId - Error querying modules table: "
+            << e.what();
+        LOGERROR(errorMsg.str());
+        return -1;
+    }
+
+    return 0;
+}
+
 /**
  * Insert the Module record, if module name does not already exist in modules table.
  * Returns Module Id associated with the Module record.
@@ -2873,37 +2904,35 @@ int TskImgDBPostgreSQL::addModule(const std::string& name, const std::string& de
     if (!initialized())
         return 0;
 
-    stringstream stmt;
+    moduleId = 0;
 
-    stmt << "SELECT module_id FROM modules WHERE name = " << m_dbConnection->quote(name);
+    if (getModuleId(name, moduleId) == 0 && moduleId > 0)
+        return 0;
 
     try 
     {
-        work W(*m_dbConnection);
-        result R = W.exec(stmt);
+        stringstream stmt;
 
-        if (R.size() == 1)
-        {
-            // Already exists, return module_id
-            R[0][0].to(moduleId);
-            return 0;
-        }
-
-        // Insert a new one
-        stmt.str("");
+        work W(*m_dbConnection);
         stmt << "INSERT INTO modules (module_id, name, description) VALUES (DEFAULT, " << m_dbConnection->quote(name) << ", " << m_dbConnection->quote(description) << ")"
              << " RETURNING module_id";
 
-        R = W.exec(stmt);
+        pqxx::result R = W.exec(stmt);
 
         // Get the newly assigned module id
         R[0][0].to(moduleId);
         W.commit();
-    } 
+    }
+    catch (pqxx::unique_violation&)
+    {
+        // The module may have been added between our initial call
+        // to getModuleId() and the subsequent INSERT attempt.
+        getModuleId(name, moduleId);
+    }
     catch (const exception &e)
     {
-        std::wstringstream errorMsg;
-        errorMsg << L"TskDBPostgreSQL::addModule - Error inserting into modules table: "
+        std::stringstream errorMsg;
+        errorMsg << "TskDBPostgreSQL::addModule - Error inserting into modules table: "
             << e.what();
         LOGERROR(errorMsg.str());
         return -1;
diff --git a/framework/tsk/framework/services/TskImgDBPostgreSQL.h b/framework/tsk/framework/services/TskImgDBPostgreSQL.h
index edd8ba3e6..c57ecb463 100755
--- a/framework/tsk/framework/services/TskImgDBPostgreSQL.h
+++ b/framework/tsk/framework/services/TskImgDBPostgreSQL.h
@@ -161,6 +161,7 @@ class TskImgDBPostgreSQL : public TskImgDB
     int getFileTypeRecords(std::string& stmt, std::list<TskFileTypeRecord>& fileTypeInfoList) const;
     vector<TskBlackboardArtifact> getArtifactsHelper(uint64_t file_id, int artifactTypeID, string artifactTypeName);
     void getCarvedFileInfo(const std::string& stmt, std::map<uint64_t, std::string>& results) const;
+    int getModuleId(const std::string& name, int& moduleId) const;
 
     /**
      * A helper function for getUniqueCarvedFilesInfo() that executes a very specific SQL SELECT statement 
-- 
GitLab