From abc9d0f686bf3f6c4d416bd5efca081eaffea0be Mon Sep 17 00:00:00 2001
From: David Norell <davno443@student.liu.se>
Date: Thu, 12 Jan 2023 13:20:13 +0000
Subject: [PATCH] Replace TM_Project.ipynb

---
 TM_Project.ipynb | 191 +++++++++++++++++++++++++++++++----------------
 1 file changed, 126 insertions(+), 65 deletions(-)

diff --git a/TM_Project.ipynb b/TM_Project.ipynb
index a825096..21407f2 100644
--- a/TM_Project.ipynb
+++ b/TM_Project.ipynb
@@ -26,13 +26,13 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 2,
+      "execution_count": 1,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
         "id": "01XiZPzPKXYa",
-        "outputId": "255e2006-4a0a-4a3b-b79e-538b309c61ad"
+        "outputId": "94120340-3561-4c52-e105-49c66ba50f31"
       },
       "outputs": [
         {
@@ -63,13 +63,13 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 3,
+      "execution_count": 2,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
         "id": "T9ZLMs_IKzap",
-        "outputId": "87878a5e-f10b-423f-ece3-6be4cb6ff286"
+        "outputId": "aafb6570-5170-4b8a-dce9-1de24a402ba3"
       },
       "outputs": [
         {
@@ -109,14 +109,14 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 4,
+      "execution_count": 3,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/",
           "height": 612
         },
         "id": "wJ77wcuhMr_8",
-        "outputId": "710566bc-d672-4919-c404-2a59edb022b1"
+        "outputId": "d3c1600e-dc50-4958-94d1-d754cf414130"
       },
       "outputs": [
         {
@@ -169,7 +169,7 @@
             ],
             "text/html": [
               "\n",
-              "  <div id=\"df-5ce28d08-ca3d-4748-9bac-4527eec1ac1a\">\n",
+              "  <div id=\"df-b064dcb5-b1de-4b25-821e-a0225f288f09\">\n",
               "    <div class=\"colab-df-container\">\n",
               "      <div>\n",
               "<style scoped>\n",
@@ -337,7 +337,7 @@
               "</table>\n",
               "<p>5 rows × 31 columns</p>\n",
               "</div>\n",
-              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-5ce28d08-ca3d-4748-9bac-4527eec1ac1a')\"\n",
+              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-b064dcb5-b1de-4b25-821e-a0225f288f09')\"\n",
               "              title=\"Convert this dataframe to an interactive table.\"\n",
               "              style=\"display:none;\">\n",
               "        \n",
@@ -388,12 +388,12 @@
               "\n",
               "      <script>\n",
               "        const buttonEl =\n",
-              "          document.querySelector('#df-5ce28d08-ca3d-4748-9bac-4527eec1ac1a button.colab-df-convert');\n",
+              "          document.querySelector('#df-b064dcb5-b1de-4b25-821e-a0225f288f09 button.colab-df-convert');\n",
               "        buttonEl.style.display =\n",
               "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
               "\n",
               "        async function convertToInteractive(key) {\n",
-              "          const element = document.querySelector('#df-5ce28d08-ca3d-4748-9bac-4527eec1ac1a');\n",
+              "          const element = document.querySelector('#df-b064dcb5-b1de-4b25-821e-a0225f288f09');\n",
               "          const dataTable =\n",
               "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
               "                                                     [key], {});\n",
@@ -416,7 +416,7 @@
             ]
           },
           "metadata": {},
-          "execution_count": 4
+          "execution_count": 3
         }
       ],
       "source": [
@@ -434,7 +434,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 5,
+      "execution_count": 4,
       "metadata": {
         "id": "lwcbT8ijd9-c"
       },
@@ -470,13 +470,13 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 6,
+      "execution_count": 5,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
         "id": "IVyaai91ddBO",
-        "outputId": "7da0f151-ddba-4155-ff52-bb4a1854e59d"
+        "outputId": "1dcbaf82-7d34-4a3f-d3ac-b513c3e11f57"
       },
       "outputs": [
         {
@@ -491,7 +491,7 @@
             ]
           },
           "metadata": {},
-          "execution_count": 6
+          "execution_count": 5
         }
       ],
       "source": [
@@ -510,13 +510,13 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 7,
+      "execution_count": 6,
       "metadata": {
         "id": "EAK29Pu6X5q9",
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
-        "outputId": "6982db9d-4804-493c-9dcf-d411eba38856"
+        "outputId": "2b3b5db9-4922-4fd7-aa44-939c20466452"
       },
       "outputs": [
         {
@@ -562,7 +562,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 18,
+      "execution_count": 7,
       "metadata": {
         "id": "Jk9OaFp1en9O"
       },
@@ -589,24 +589,24 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 19,
+      "execution_count": 8,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/"
         },
         "id": "wn4tjCkFdaaw",
-        "outputId": "48a4bead-d3e3-432e-8584-8207060e6d64"
+        "outputId": "6ff2db19-a9a2-48fc-b363-7f34119564cc"
       },
       "outputs": [
         {
           "output_type": "execute_result",
           "data": {
             "text/plain": [
-              "0.6174037089871611"
+              "0.62382176520994"
             ]
           },
           "metadata": {},
-          "execution_count": 19
+          "execution_count": 8
         }
       ],
       "source": [
@@ -631,11 +631,26 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 9,
       "metadata": {
-        "id": "b3URB7EjhUC2"
+        "id": "b3URB7EjhUC2",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "fd62484e-917c-46f7-d19d-bbca831353ca"
       },
-      "outputs": [],
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "0.638993710691824"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 9
+        }
+      ],
       "source": [
         "pipe2 = Pipeline([('vec', TfidfVectorizer(max_features = 10000)), \n",
         "                  ('clf', OneVsRestClassifier(LogisticRegression(C = 10, n_jobs = -1)))])\n",
@@ -658,7 +673,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 10,
       "metadata": {
         "id": "GzAIvOXmnDwF"
       },
@@ -676,20 +691,36 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 11,
       "metadata": {
-        "id": "fBCTZdWJngAS"
+        "id": "fBCTZdWJngAS",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "b1035956-e263-41fc-f97a-910fe36fa950"
       },
-      "outputs": [],
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "0.686046511627907"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 11
+        }
+      ],
       "source": [
-        "preds3 = pipe2.predict_proba(df_test['ABSTRACT'])\n",
+        "valid_preds3 = pipe2.predict_proba(df_valid['ABSTRACT'])\n",
+        "test_preds3 = pipe2.predict_proba(df_test['ABSTRACT'])\n",
         "\n",
-        "best_thresholds = get_best_thresholds(df_test[TARGET_COLS].values, preds3)\n",
+        "best_thresholds = get_best_thresholds(df_valid[TARGET_COLS].values, valid_preds3)\n",
         "\n",
         "for i, thresh in enumerate(best_thresholds):\n",
-        "  preds3[:, i] = (preds3[:, i] > thresh) * 1\n",
+        "  test_preds3[:, i] = (test_preds3[:, i] > thresh) * 1\n",
         "  \n",
-        "f1_score(df_test[TARGET_COLS], preds3, average='micro')"
+        "f1_score(df_test[TARGET_COLS], test_preds3, average='micro')"
       ]
     },
     {
@@ -704,52 +735,82 @@
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 13,
       "metadata": {
-        "id": "8cwaFlwYs8YZ"
+        "id": "8cwaFlwYs8YZ",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "63a2a79e-75e2-48f3-f97b-07eaee6e72a5"
       },
-      "outputs": [],
+      "outputs": [
+        {
+          "output_type": "stream",
+          "name": "stdout",
+          "text": [
+            "(11203, 10004) (1400, 10004) (1401, 10004)\n"
+          ]
+        }
+      ],
       "source": [
         "vec = CountVectorizer(max_features = 10000)\n",
         "_ = vec.fit(list(df_train['ABSTRACT']) + list(df_test['ABSTRACT']))\n",
         "\n",
         "train2 = np.hstack((vec.transform(df_train['ABSTRACT']).toarray(), df_train[TOPIC_COLS]))\n",
+        "valid2 = np.hstack((vec.transform(df_valid['ABSTRACT']).toarray(), df_valid[TOPIC_COLS]))\n",
         "test2 = np.hstack((vec.transform(df_test['ABSTRACT']).toarray(), df_test[TOPIC_COLS]))\n",
         "\n",
-        "print(train2.shape, test2.shape)"
+        "print(train2.shape, valid2.shape, test2.shape)"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 14,
       "metadata": {
         "id": "15FQQT9ls_Xq"
       },
       "outputs": [],
       "source": [
         "trn2 = csr_matrix(train2.astype('int16'))\n",
-        "val2 = csr_matrix(test2.astype('int16'))"
+        "val2 = csr_matrix(valid2.astype('int16'))\n",
+        "tes2 = csr_matrix(test2.astype('int16'))"
       ]
     },
     {
       "cell_type": "code",
-      "execution_count": null,
+      "execution_count": 15,
       "metadata": {
-        "id": "sgWCfox7tF0X"
+        "id": "sgWCfox7tF0X",
+        "colab": {
+          "base_uri": "https://localhost:8080/"
+        },
+        "outputId": "885c80fd-5f5a-4ff7-c608-fd829a778e9a"
       },
-      "outputs": [],
+      "outputs": [
+        {
+          "output_type": "execute_result",
+          "data": {
+            "text/plain": [
+              "0.6919053549190535"
+            ]
+          },
+          "metadata": {},
+          "execution_count": 15
+        }
+      ],
       "source": [
         "clf = OneVsRestClassifier(LogisticRegression(C = 10, n_jobs=-1))\n",
         "_  = clf.fit(train2, df_train[TARGET_COLS])\n",
         "\n",
-        "preds4 = clf.predict_proba(test2)\n",
+        "valid_preds4 = clf.predict_proba(valid2)\n",
+        "test_preds4 = clf.predict_proba(test2)\n",
         "\n",
-        "best_thresholds = get_best_thresholds(df_test[TARGET_COLS].values, preds4)\n",
+        "best_thresholds = get_best_thresholds(df_valid[TARGET_COLS].values, valid_preds4)\n",
         "\n",
         "for i, thresh in enumerate(best_thresholds):\n",
-        "  preds4[:, i] = (preds4[:, i] > thresh) * 1\n",
+        "  test_preds4[:, i] = (test_preds4[:, i] > thresh) * 1\n",
         "\n",
-        "f1_score(df_test[TARGET_COLS], preds4, average = 'micro')"
+        "f1_score(df_test[TARGET_COLS], test_preds4, average = 'micro')"
       ]
     },
     {
@@ -761,10 +822,10 @@
         "## Results\n",
         "Below are the f1-scores produced by each baseline configuration above. The f1-score is computed globally by counting the total TP, FN, and FP, according to the 'micro' parameter for the f1-score function.\n",
         "\n",
-        "* Countvectorizer: **0.622**\n",
-        "* Tfidfvectorizer: **0.637**\n",
-        "* Optimal threshold: **0.689**\n",
-        "* Combining topics: **0.714**"
+        "* Countvectorizer: **0.624**\n",
+        "* Tfidfvectorizer: **0.639**\n",
+        "* Optimal threshold: **0.686**\n",
+        "* Combining topics: **0.692**"
       ]
     },
     {
@@ -781,7 +842,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 8,
+      "execution_count": null,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/"
@@ -828,7 +889,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 9,
+      "execution_count": null,
       "metadata": {
         "id": "y4eT5t-9td2O"
       },
@@ -857,7 +918,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 10,
+      "execution_count": null,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/"
@@ -891,7 +952,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 11,
+      "execution_count": null,
       "metadata": {
         "id": "IuX8ZiPAtgXw",
         "colab": {
@@ -1025,7 +1086,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 14,
+      "execution_count": null,
       "metadata": {
         "id": "P2FIjkfouyV_"
       },
@@ -1077,7 +1138,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 12,
+      "execution_count": null,
       "metadata": {
         "colab": {
           "base_uri": "https://localhost:8080/",
@@ -1154,7 +1215,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 13,
+      "execution_count": null,
       "metadata": {
         "id": "YZqoWbnxvhJr"
       },
@@ -1176,7 +1237,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 15,
+      "execution_count": null,
       "metadata": {
         "id": "raChNaDFvoXD"
       },
@@ -1231,7 +1292,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 16,
+      "execution_count": null,
       "metadata": {
         "id": "3XgL-Jko9Wx5"
       },
@@ -1273,7 +1334,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 17,
+      "execution_count": null,
       "metadata": {
         "id": "iuzdvuVjvr_n"
       },
@@ -1332,7 +1393,7 @@
       "metadata": {
         "id": "kPFNINne8PGO"
       },
-      "execution_count": 18,
+      "execution_count": null,
       "outputs": []
     },
     {
@@ -1441,7 +1502,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 30,
+      "execution_count": null,
       "metadata": {
         "id": "Lj7B9J7W0zOD"
       },
@@ -1486,7 +1547,7 @@
     },
     {
       "cell_type": "code",
-      "execution_count": 31,
+      "execution_count": null,
       "metadata": {
         "id": "vGLrHjX62YZq",
         "colab": {
@@ -1630,7 +1691,7 @@
       "metadata": {
         "id": "6obBj-9qjZlY"
       },
-      "execution_count": 19,
+      "execution_count": null,
       "outputs": []
     },
     {
@@ -1647,7 +1708,7 @@
         "id": "OMBdm2BInez4",
         "outputId": "87b351c3-c72b-42c7-c408-fb61fa073b0a"
       },
-      "execution_count": 20,
+      "execution_count": null,
       "outputs": [
         {
           "output_type": "stream",
-- 
GitLab