From abc9d0f686bf3f6c4d416bd5efca081eaffea0be Mon Sep 17 00:00:00 2001 From: David Norell <davno443@student.liu.se> Date: Thu, 12 Jan 2023 13:20:13 +0000 Subject: [PATCH] Replace TM_Project.ipynb --- TM_Project.ipynb | 191 +++++++++++++++++++++++++++++++---------------- 1 file changed, 126 insertions(+), 65 deletions(-) diff --git a/TM_Project.ipynb b/TM_Project.ipynb index a825096..21407f2 100644 --- a/TM_Project.ipynb +++ b/TM_Project.ipynb @@ -26,13 +26,13 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "01XiZPzPKXYa", - "outputId": "255e2006-4a0a-4a3b-b79e-538b309c61ad" + "outputId": "94120340-3561-4c52-e105-49c66ba50f31" }, "outputs": [ { @@ -63,13 +63,13 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "T9ZLMs_IKzap", - "outputId": "87878a5e-f10b-423f-ece3-6be4cb6ff286" + "outputId": "aafb6570-5170-4b8a-dce9-1de24a402ba3" }, "outputs": [ { @@ -109,14 +109,14 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 612 }, "id": "wJ77wcuhMr_8", - "outputId": "710566bc-d672-4919-c404-2a59edb022b1" + "outputId": "d3c1600e-dc50-4958-94d1-d754cf414130" }, "outputs": [ { @@ -169,7 +169,7 @@ ], "text/html": [ "\n", - " <div id=\"df-5ce28d08-ca3d-4748-9bac-4527eec1ac1a\">\n", + " <div id=\"df-b064dcb5-b1de-4b25-821e-a0225f288f09\">\n", " <div class=\"colab-df-container\">\n", " <div>\n", "<style scoped>\n", @@ -337,7 +337,7 @@ "</table>\n", "<p>5 rows × 31 columns</p>\n", "</div>\n", - " <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-5ce28d08-ca3d-4748-9bac-4527eec1ac1a')\"\n", + " <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-b064dcb5-b1de-4b25-821e-a0225f288f09')\"\n", " title=\"Convert this dataframe to an interactive table.\"\n", " style=\"display:none;\">\n", " \n", @@ -388,12 +388,12 @@ "\n", " <script>\n", " const buttonEl =\n", - " document.querySelector('#df-5ce28d08-ca3d-4748-9bac-4527eec1ac1a button.colab-df-convert');\n", + " document.querySelector('#df-b064dcb5-b1de-4b25-821e-a0225f288f09 button.colab-df-convert');\n", " buttonEl.style.display =\n", " google.colab.kernel.accessAllowed ? 'block' : 'none';\n", "\n", " async function convertToInteractive(key) {\n", - " const element = document.querySelector('#df-5ce28d08-ca3d-4748-9bac-4527eec1ac1a');\n", + " const element = document.querySelector('#df-b064dcb5-b1de-4b25-821e-a0225f288f09');\n", " const dataTable =\n", " await google.colab.kernel.invokeFunction('convertToInteractive',\n", " [key], {});\n", @@ -416,7 +416,7 @@ ] }, "metadata": {}, - "execution_count": 4 + "execution_count": 3 } ], "source": [ @@ -434,7 +434,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": { "id": "lwcbT8ijd9-c" }, @@ -470,13 +470,13 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "IVyaai91ddBO", - "outputId": "7da0f151-ddba-4155-ff52-bb4a1854e59d" + "outputId": "1dcbaf82-7d34-4a3f-d3ac-b513c3e11f57" }, "outputs": [ { @@ -491,7 +491,7 @@ ] }, "metadata": {}, - "execution_count": 6 + "execution_count": 5 } ], "source": [ @@ -510,13 +510,13 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": { "id": "EAK29Pu6X5q9", "colab": { "base_uri": "https://localhost:8080/" }, - "outputId": "6982db9d-4804-493c-9dcf-d411eba38856" + "outputId": "2b3b5db9-4922-4fd7-aa44-939c20466452" }, "outputs": [ { @@ -562,7 +562,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 7, "metadata": { "id": "Jk9OaFp1en9O" }, @@ -589,24 +589,24 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "wn4tjCkFdaaw", - "outputId": "48a4bead-d3e3-432e-8584-8207060e6d64" + "outputId": "6ff2db19-a9a2-48fc-b363-7f34119564cc" }, "outputs": [ { "output_type": "execute_result", "data": { "text/plain": [ - "0.6174037089871611" + "0.62382176520994" ] }, "metadata": {}, - "execution_count": 19 + "execution_count": 8 } ], "source": [ @@ -631,11 +631,26 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 9, "metadata": { - "id": "b3URB7EjhUC2" + "id": "b3URB7EjhUC2", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "fd62484e-917c-46f7-d19d-bbca831353ca" }, - "outputs": [], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "0.638993710691824" + ] + }, + "metadata": {}, + "execution_count": 9 + } + ], "source": [ "pipe2 = Pipeline([('vec', TfidfVectorizer(max_features = 10000)), \n", " ('clf', OneVsRestClassifier(LogisticRegression(C = 10, n_jobs = -1)))])\n", @@ -658,7 +673,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 10, "metadata": { "id": "GzAIvOXmnDwF" }, @@ -676,20 +691,36 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 11, "metadata": { - "id": "fBCTZdWJngAS" + "id": "fBCTZdWJngAS", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "b1035956-e263-41fc-f97a-910fe36fa950" }, - "outputs": [], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "0.686046511627907" + ] + }, + "metadata": {}, + "execution_count": 11 + } + ], "source": [ - "preds3 = pipe2.predict_proba(df_test['ABSTRACT'])\n", + "valid_preds3 = pipe2.predict_proba(df_valid['ABSTRACT'])\n", + "test_preds3 = pipe2.predict_proba(df_test['ABSTRACT'])\n", "\n", - "best_thresholds = get_best_thresholds(df_test[TARGET_COLS].values, preds3)\n", + "best_thresholds = get_best_thresholds(df_valid[TARGET_COLS].values, valid_preds3)\n", "\n", "for i, thresh in enumerate(best_thresholds):\n", - " preds3[:, i] = (preds3[:, i] > thresh) * 1\n", + " test_preds3[:, i] = (test_preds3[:, i] > thresh) * 1\n", " \n", - "f1_score(df_test[TARGET_COLS], preds3, average='micro')" + "f1_score(df_test[TARGET_COLS], test_preds3, average='micro')" ] }, { @@ -704,52 +735,82 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 13, "metadata": { - "id": "8cwaFlwYs8YZ" + "id": "8cwaFlwYs8YZ", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "63a2a79e-75e2-48f3-f97b-07eaee6e72a5" }, - "outputs": [], + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "(11203, 10004) (1400, 10004) (1401, 10004)\n" + ] + } + ], "source": [ "vec = CountVectorizer(max_features = 10000)\n", "_ = vec.fit(list(df_train['ABSTRACT']) + list(df_test['ABSTRACT']))\n", "\n", "train2 = np.hstack((vec.transform(df_train['ABSTRACT']).toarray(), df_train[TOPIC_COLS]))\n", + "valid2 = np.hstack((vec.transform(df_valid['ABSTRACT']).toarray(), df_valid[TOPIC_COLS]))\n", "test2 = np.hstack((vec.transform(df_test['ABSTRACT']).toarray(), df_test[TOPIC_COLS]))\n", "\n", - "print(train2.shape, test2.shape)" + "print(train2.shape, valid2.shape, test2.shape)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 14, "metadata": { "id": "15FQQT9ls_Xq" }, "outputs": [], "source": [ "trn2 = csr_matrix(train2.astype('int16'))\n", - "val2 = csr_matrix(test2.astype('int16'))" + "val2 = csr_matrix(valid2.astype('int16'))\n", + "tes2 = csr_matrix(test2.astype('int16'))" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 15, "metadata": { - "id": "sgWCfox7tF0X" + "id": "sgWCfox7tF0X", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "885c80fd-5f5a-4ff7-c608-fd829a778e9a" }, - "outputs": [], + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "0.6919053549190535" + ] + }, + "metadata": {}, + "execution_count": 15 + } + ], "source": [ "clf = OneVsRestClassifier(LogisticRegression(C = 10, n_jobs=-1))\n", "_ = clf.fit(train2, df_train[TARGET_COLS])\n", "\n", - "preds4 = clf.predict_proba(test2)\n", + "valid_preds4 = clf.predict_proba(valid2)\n", + "test_preds4 = clf.predict_proba(test2)\n", "\n", - "best_thresholds = get_best_thresholds(df_test[TARGET_COLS].values, preds4)\n", + "best_thresholds = get_best_thresholds(df_valid[TARGET_COLS].values, valid_preds4)\n", "\n", "for i, thresh in enumerate(best_thresholds):\n", - " preds4[:, i] = (preds4[:, i] > thresh) * 1\n", + " test_preds4[:, i] = (test_preds4[:, i] > thresh) * 1\n", "\n", - "f1_score(df_test[TARGET_COLS], preds4, average = 'micro')" + "f1_score(df_test[TARGET_COLS], test_preds4, average = 'micro')" ] }, { @@ -761,10 +822,10 @@ "## Results\n", "Below are the f1-scores produced by each baseline configuration above. The f1-score is computed globally by counting the total TP, FN, and FP, according to the 'micro' parameter for the f1-score function.\n", "\n", - "* Countvectorizer: **0.622**\n", - "* Tfidfvectorizer: **0.637**\n", - "* Optimal threshold: **0.689**\n", - "* Combining topics: **0.714**" + "* Countvectorizer: **0.624**\n", + "* Tfidfvectorizer: **0.639**\n", + "* Optimal threshold: **0.686**\n", + "* Combining topics: **0.692**" ] }, { @@ -781,7 +842,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -828,7 +889,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": { "id": "y4eT5t-9td2O" }, @@ -857,7 +918,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -891,7 +952,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": { "id": "IuX8ZiPAtgXw", "colab": { @@ -1025,7 +1086,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": { "id": "P2FIjkfouyV_" }, @@ -1077,7 +1138,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -1154,7 +1215,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": { "id": "YZqoWbnxvhJr" }, @@ -1176,7 +1237,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": { "id": "raChNaDFvoXD" }, @@ -1231,7 +1292,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": { "id": "3XgL-Jko9Wx5" }, @@ -1273,7 +1334,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": { "id": "iuzdvuVjvr_n" }, @@ -1332,7 +1393,7 @@ "metadata": { "id": "kPFNINne8PGO" }, - "execution_count": 18, + "execution_count": null, "outputs": [] }, { @@ -1441,7 +1502,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "metadata": { "id": "Lj7B9J7W0zOD" }, @@ -1486,7 +1547,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": null, "metadata": { "id": "vGLrHjX62YZq", "colab": { @@ -1630,7 +1691,7 @@ "metadata": { "id": "6obBj-9qjZlY" }, - "execution_count": 19, + "execution_count": null, "outputs": [] }, { @@ -1647,7 +1708,7 @@ "id": "OMBdm2BInez4", "outputId": "87b351c3-c72b-42c7-c408-fb61fa073b0a" }, - "execution_count": 20, + "execution_count": null, "outputs": [ { "output_type": "stream", -- GitLab