From 6140d8d228b1d428c1b13baca6581ffbd6a04e8a Mon Sep 17 00:00:00 2001
From: Filip Johnsson <filjo653@student.liu.se>
Date: Mon, 25 Nov 2024 16:19:47 +0100
Subject: [PATCH] lab3 working

---
 l3/TM-Lab3.ipynb | 222 +++++++++++++++++++++++++++++++++++++++++++----
 1 file changed, 205 insertions(+), 17 deletions(-)

diff --git a/l3/TM-Lab3.ipynb b/l3/TM-Lab3.ipynb
index b14e058..1e0074f 100644
--- a/l3/TM-Lab3.ipynb
+++ b/l3/TM-Lab3.ipynb
@@ -16727,7 +16727,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 58,
+   "execution_count": 62,
    "metadata": {
     "deletable": false,
     "nbgrader": {
@@ -16784,7 +16784,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 61,
+   "execution_count": 63,
    "metadata": {
     "deletable": false,
     "editable": false,
@@ -16803,7 +16803,27 @@
      "solution"
     ]
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Precision: 0.668, Recall: 0.610, F1: 0.638\n"
+     ]
+    },
+    {
+     "data": {
+      "text/html": [
+       "<div class=\"alert alert-success\"><strong>Checks have passed!</strong></div>"
+      ],
+      "text/plain": [
+       "<IPython.core.display.HTML object>"
+      ]
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    }
+   ],
    "source": [
     "scores = evaluation_scores(dev_gold_mentions, set(most_probable_method(df_dev_pred, df_kb)))\n",
     "print_evaluation_scores(scores)\n",
@@ -16831,7 +16851,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 64,
    "metadata": {
     "deletable": false,
     "editable": false,
@@ -16861,9 +16881,91 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 65,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       "<div>\n",
+       "<style scoped>\n",
+       "    .dataframe tbody tr th:only-of-type {\n",
+       "        vertical-align: middle;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe tbody tr th {\n",
+       "        vertical-align: top;\n",
+       "    }\n",
+       "\n",
+       "    .dataframe thead th {\n",
+       "        text-align: right;\n",
+       "    }\n",
+       "</style>\n",
+       "<table border=\"1\" class=\"dataframe\">\n",
+       "  <thead>\n",
+       "    <tr style=\"text-align: right;\">\n",
+       "      <th></th>\n",
+       "      <th>mention</th>\n",
+       "      <th>entity</th>\n",
+       "      <th>context</th>\n",
+       "    </tr>\n",
+       "  </thead>\n",
+       "  <tbody>\n",
+       "    <tr>\n",
+       "      <th>0</th>\n",
+       "      <td>1970</td>\n",
+       "      <td>UEFA_Champions_League</td>\n",
+       "      <td>Cup twice the first in @ and the second in 1983</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>1</th>\n",
+       "      <td>1970</td>\n",
+       "      <td>FIFA_World_Cup</td>\n",
+       "      <td>America 1975 and during the @ and 1978 World C...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>2</th>\n",
+       "      <td>1990 World Cup</td>\n",
+       "      <td>1990_FIFA_World_Cup</td>\n",
+       "      <td>Manolo represented Spain at the @</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>3</th>\n",
+       "      <td>1990 World Cup</td>\n",
+       "      <td>1990_FIFA_World_Cup</td>\n",
+       "      <td>Hašek represented Czechoslovakia at the @ and ...</td>\n",
+       "    </tr>\n",
+       "    <tr>\n",
+       "      <th>4</th>\n",
+       "      <td>1990 World Cup</td>\n",
+       "      <td>1990_FIFA_World_Cup</td>\n",
+       "      <td>renovations in 1989 for the @ The present capa...</td>\n",
+       "    </tr>\n",
+       "  </tbody>\n",
+       "</table>\n",
+       "</div>"
+      ],
+      "text/plain": [
+       "          mention                 entity  \\\n",
+       "0            1970  UEFA_Champions_League   \n",
+       "1            1970         FIFA_World_Cup   \n",
+       "2  1990 World Cup    1990_FIFA_World_Cup   \n",
+       "3  1990 World Cup    1990_FIFA_World_Cup   \n",
+       "4  1990 World Cup    1990_FIFA_World_Cup   \n",
+       "\n",
+       "                                             context  \n",
+       "0    Cup twice the first in @ and the second in 1983  \n",
+       "1  America 1975 and during the @ and 1978 World C...  \n",
+       "2                 Manolo represented Spain at the @   \n",
+       "3  Hašek represented Czechoslovakia at the @ and ...  \n",
+       "4  renovations in 1989 for the @ The present capa...  "
+      ]
+     },
+     "execution_count": 65,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "df_contexts.head()"
    ]
@@ -16879,9 +16981,31 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 66,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "41465    Nebraska Concealed Handgun Permit In @ municip...\n",
+       "41466    Lazlo restaurants are located in @ and Omaha C...\n",
+       "41467    California Washington Overland Park Kansas @ N...\n",
+       "41468    City Missouri Omaha Nebraska and @ Nebraska It...\n",
+       "41469    by Sandhills Publishing Company in @ Nebraska USA\n",
+       "                               ...                        \n",
+       "41609                                      @ Leyton Orient\n",
+       "41610                    English division three Swansea @ \n",
+       "41611    league membership narrowly edging out @ on goa...\n",
+       "41612                                          @ Cambridge\n",
+       "41613                                                   @ \n",
+       "Name: context, Length: 149, dtype: object"
+      ]
+     },
+     "execution_count": 66,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
    "source": [
     "df_contexts.context[df_contexts.mention == 'Lincoln']"
    ]
@@ -16901,7 +17025,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 105,
    "metadata": {
     "deletable": false,
     "nbgrader": {
@@ -16920,6 +17044,10 @@
    },
    "outputs": [],
    "source": [
+    "from sklearn.naive_bayes import MultinomialNB\n",
+    "from sklearn.pipeline import Pipeline\n",
+    "from sklearn.feature_extraction.text import CountVectorizer\n",
+    "\n",
     "def build_entity_classifiers(df_kb, df_contexts):\n",
     "    \"\"\"Build Naive Bayes classifiers for entity prediction.\n",
     "\n",
@@ -16932,13 +17060,45 @@
     "        classifiers trained to predict the correct entity, given the textual\n",
     "        context of the mention (as described in detail above).\n",
     "    \"\"\"\n",
-    "    # YOUR CODE HERE\n",
-    "    raise NotImplementedError()"
+    "    # Group the training contexts (X) and gold entities (y) by mention.\n",
+    "    training_data = {}\n",
+    "    for row in df_contexts.itertuples(index=False):\n",
+    "        contexts, entities = training_data.setdefault(row.mention, ([], []))\n",
+    "        contexts.append(row.context)\n",
+    "        entities.append(row.entity)\n",
+    "\n",
+    "    # Train one bag-of-words Naive Bayes pipeline per mention.\n",
+    "    classifiers = {}\n",
+    "    for mention, (contexts, entities) in training_data.items():\n",
+    "        trainX = np.array(contexts)\n",
+    "        trainY = np.array(entities)\n",
+    "\n",
+    "        # Prior probability of each candidate entity of this mention, from the KB.\n",
+    "        kb_rows = df_kb.loc[df_kb.mention == mention]\n",
+    "        kb_prior = dict(zip(kb_rows[\"entity\"], kb_rows[\"prob\"]))\n",
+    "\n",
+    "        # MultinomialNB expects one prior per class, ordered like its sorted\n",
+    "        # class labels. Align the KB priors with the entities actually seen in\n",
+    "        # training, so KB entities without training contexts cannot cause a\n",
+    "        # length mismatch or put a prior on the wrong class.\n",
+    "        labels = np.unique(trainY)\n",
+    "        if all(label in kb_prior for label in labels):\n",
+    "            class_prior = [kb_prior[label] for label in labels]\n",
+    "        else:\n",
+    "            class_prior = None  # no KB prior for some class: estimate from data\n",
+    "\n",
+    "        pipeline = Pipeline(steps=[\n",
+    "            ('vectorizer', CountVectorizer()),\n",
+    "            ('classifier', MultinomialNB(class_prior=class_prior)),\n",
+    "        ])\n",
+    "        classifiers[mention] = pipeline.fit(trainX, trainY)\n",
+    "\n",
+    "    return classifiers"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 106,
    "metadata": {
     "deletable": false,
     "nbgrader": {
@@ -16972,8 +17132,28 @@
     "        quadruples consisting of the sentence id, start position, end\n",
     "        position and the predicted entity label of each span.\n",
     "    \"\"\"\n",
-    "    # YOUR CODE HERE\n",
-    "    raise NotImplementedError()"
+    "    # Most probable KB entity per mention, computed once; the stable sort makes\n",
+    "    # the result independent of the row order of df_kb (ties keep the first row).\n",
+    "    ranked = df_kb.sort_values(\"prob\", ascending=False, kind=\"stable\")\n",
+    "    fallback = ranked.drop_duplicates(\"mention\").set_index(\"mention\")[\"entity\"].to_dict()\n",
+    "\n",
+    "    for row in df.itertuples():\n",
+    "        # Positional fields: sentence id, sentence text, span start, span end.\n",
+    "        sentence_id, sentence, start, end = row[1], row[2], row[3], row[4]\n",
+    "        lo, hi = int(start), int(end)\n",
+    "        # The mention is the token span; its context is the sentence with the\n",
+    "        # span collapsed into a single '@' placeholder, as in the training data.\n",
+    "        words = sentence.split()\n",
+    "        mention = ' '.join(words[lo:hi])\n",
+    "        context = ' '.join(words[:lo] + ['@'] + words[hi:])\n",
+    "\n",
+    "        if mention in classifiers:\n",
+    "            predicted_entity = classifiers[mention].predict([context])[0]\n",
+    "        else:\n",
+    "            # No classifier: most probable KB entity, or the no-match label.\n",
+    "            predicted_entity = fallback.get(mention, \"--NME--\")\n",
+    "\n",
+    "        yield sentence_id, start, end, predicted_entity"
    ]
   },
   {
@@ -16987,7 +17167,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 107,
    "metadata": {
     "deletable": false,
     "editable": false,
@@ -17018,7 +17198,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 108,
    "metadata": {
     "deletable": false,
     "editable": false,
@@ -17037,7 +17217,15 @@
      "solution"
     ]
    },
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Precision: 0.686, Recall: 0.627, F1: 0.655\n"
+     ]
+    }
+   ],
    "source": [
     "scores = evaluation_scores(dev_gold_mentions, dev_pred_dict_mentions)\n",
     "print_evaluation_scores(scores)"
-- 
GitLab