From 6140d8d228b1d428c1b13baca6581ffbd6a04e8a Mon Sep 17 00:00:00 2001 From: Filip Johnsson <filjo653@student.liu.se> Date: Mon, 25 Nov 2024 16:19:47 +0100 Subject: [PATCH] lab3 working --- l3/TM-Lab3.ipynb | 222 +++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 205 insertions(+), 17 deletions(-) diff --git a/l3/TM-Lab3.ipynb b/l3/TM-Lab3.ipynb index b14e058..1e0074f 100644 --- a/l3/TM-Lab3.ipynb +++ b/l3/TM-Lab3.ipynb @@ -16727,7 +16727,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": 62, "metadata": { "deletable": false, "nbgrader": { @@ -16784,7 +16784,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 63, "metadata": { "deletable": false, "editable": false, @@ -16803,7 +16803,27 @@ "solution" ] }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Precision: 0.668, Recall: 0.610, F1: 0.638\n" + ] + }, + { + "data": { + "text/html": [ + "<div class=\"alert alert-success\"><strong>Checks have passed!</strong></div>" + ], + "text/plain": [ + "<IPython.core.display.HTML object>" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ "scores = evaluation_scores(dev_gold_mentions, set(most_probable_method(df_dev_pred, df_kb)))\n", "print_evaluation_scores(scores)\n", @@ -16831,7 +16851,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 64, "metadata": { "deletable": false, "editable": false, @@ -16861,9 +16881,91 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 65, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>mention</th>\n", + " <th>entity</th>\n", + " <th>context</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>1970</td>\n", + " <td>UEFA_Champions_League</td>\n", + " <td>Cup twice the first in @ and the second in 1983</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>1970</td>\n", + " <td>FIFA_World_Cup</td>\n", + " <td>America 1975 and during the @ and 1978 World C...</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>1990 World Cup</td>\n", + " <td>1990_FIFA_World_Cup</td>\n", + " <td>Manolo represented Spain at the @</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>1990 World Cup</td>\n", + " <td>1990_FIFA_World_Cup</td>\n", + " <td>Hašek represented Czechoslovakia at the @ and ...</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>1990 World Cup</td>\n", + " <td>1990_FIFA_World_Cup</td>\n", + " <td>renovations in 1989 for the @ The present capa...</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " mention entity \\\n", + "0 1970 UEFA_Champions_League \n", + "1 1970 FIFA_World_Cup \n", + "2 1990 World Cup 1990_FIFA_World_Cup \n", + "3 1990 World Cup 1990_FIFA_World_Cup \n", + "4 1990 World Cup 1990_FIFA_World_Cup \n", + "\n", + " context \n", + "0 Cup twice the first in @ and the second in 1983 \n", + "1 America 1975 and during the @ and 1978 World C... \n", + "2 Manolo represented Spain at the @ \n", + "3 Hašek represented Czechoslovakia at the @ and ... \n", + "4 renovations in 1989 for the @ The present capa... " + ] + }, + "execution_count": 65, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "df_contexts.head()" ] @@ -16879,9 +16981,31 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 66, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/plain": [ + "41465 Nebraska Concealed Handgun Permit In @ municip...\n", + "41466 Lazlo restaurants are located in @ and Omaha C...\n", + "41467 California Washington Overland Park Kansas @ N...\n", + "41468 City Missouri Omaha Nebraska and @ Nebraska It...\n", + "41469 by Sandhills Publishing Company in @ Nebraska USA\n", + " ... \n", + "41609 @ Leyton Orient\n", + "41610 English division three Swansea @ \n", + "41611 league membership narrowly edging out @ on goa...\n", + "41612 @ Cambridge\n", + "41613 @ \n", + "Name: context, Length: 149, dtype: object" + ] + }, + "execution_count": 66, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ "df_contexts.context[df_contexts.mention == 'Lincoln']" ] @@ -16901,7 +17025,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 105, "metadata": { "deletable": false, "nbgrader": { @@ -16920,6 +17044,10 @@ }, "outputs": [], "source": [ + "from sklearn.naive_bayes import MultinomialNB\n", + "from sklearn.pipeline import Pipeline\n", + "from sklearn.feature_extraction.text import CountVectorizer\n", + "\n", "def build_entity_classifiers(df_kb, df_contexts):\n", " \"\"\"Build Naive Bayes classifiers for entity prediction.\n", "\n", @@ -16932,13 +17060,45 @@ " classifiers trained to predict the correct entity, given the textual\n", " context of the mention (as described in detail above).\n", " \"\"\"\n", - " # YOUR CODE HERE\n", - " raise NotImplementedError()" + " # === Get X and Y training data ===\n", + "\n", + " mention_training_data_dict = {}\n", + "\n", + " for row in df_contexts.itertuples():\n", + " mention = row[1]\n", + " entity = row[2]\n", + " context = row[3]\n", + "\n", + " if mention not in mention_training_data_dict:\n", + " mention_training_data_dict[mention] = {'trainX': [], 'trainY': []}\n", + " \n", + " mention_training_data_dict[mention]['trainX'].append(context)\n", + " mention_training_data_dict[mention]['trainY'].append(entity)\n", + " \n", + " # === Create pipelines ===\n", + "\n", + " mention_classifier_dict = {}\n", + "\n", + " for mention, training_data in mention_training_data_dict.items():\n", + " probs = df_kb.loc[df_kb.mention == mention].sort_values(\"entity\")\n", + " class_prior = list(probs[\"prob\"])\n", + "\n", + " pipeline = Pipeline(steps=[\n", + " ('vectorizer', CountVectorizer()),\n", + " ('classifier', MultinomialNB(class_prior=class_prior)),\n", + " ])\n", + " trainX = np.array(training_data['trainX'])\n", + " trainY = np.array(training_data['trainY'])\n", + " model = pipeline.fit(trainX, trainY)\n", + "\n", + " mention_classifier_dict[mention] = model\n", + " \n", + " return mention_classifier_dict" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 106, "metadata": { "deletable": false, "nbgrader": { @@ -16972,8 +17132,28 @@ " quadruples consisting of the sentence id, start position, end\n", " position and the predicted entity label of each span.\n", " \"\"\"\n", - " # YOUR CODE HERE\n", - " raise NotImplementedError()" + " for row in df.itertuples():\n", + " sentence_id = row[1]\n", + " sentence = row[2]\n", + " start = row[3]\n", + " end = row[4]\n", + "\n", + " words = sentence.split()\n", + "\n", + " mention = ' '.join(words[int(start):int(end)])\n", + " context = ' '.join(words[:int(start)]+['@']+words[int(end):])\n", + " \n", + " if mention in classifiers:\n", + " classifier = classifiers[mention]\n", + " predicted_entity = classifier.predict([context])[0]\n", + " \n", + " else:\n", + " entity = df_kb.loc[df_kb.mention == mention].entity\n", + " predicted_entity = entity.values[0] if len(entity.values) > 0 else \"--NME--\"\n", + "\n", + " yield sentence_id, start, end, predicted_entity\n", + "\n", + " " ] }, { @@ -16987,7 +17167,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 107, "metadata": { "deletable": false, "editable": false, @@ -17018,7 +17198,7 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 108, "metadata": { "deletable": false, "editable": false, @@ -17037,7 +17217,15 @@ "solution" ] }, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Precision: 0.686, Recall: 0.627, F1: 0.655\n" + ] + } + ], "source": [ "scores = evaluation_scores(dev_gold_mentions, dev_pred_dict_mentions)\n", "print_evaluation_scores(scores)" -- GitLab