diff --git a/code.ipynb b/code.ipynb index 71797ae083d89865d53e162031176f929c5629a7..d50b6f2ed76bf985c5d7e86329287591ef8bfa4d 100644 --- a/code.ipynb +++ b/code.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 314, + "execution_count": 485, "metadata": {}, "outputs": [], "source": [ @@ -30,12 +30,14 @@ "from sklearn.gaussian_process.kernels import RBF\n", "from sklearn.linear_model import LogisticRegression\n", "from sklearn.pipeline import make_pipeline\n", - "from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score" + "from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score\n", + "\n", + "pd.options.mode.copy_on_write = True" ] }, { "cell_type": "code", - "execution_count": 279, + "execution_count": 439, "metadata": {}, "outputs": [ { @@ -360,7 +362,7 @@ }, { "cell_type": "code", - "execution_count": 280, + "execution_count": 440, "metadata": {}, "outputs": [], "source": [ @@ -387,7 +389,7 @@ }, { "cell_type": "code", - "execution_count": 281, + "execution_count": 441, "metadata": {}, "outputs": [], "source": [ @@ -482,7 +484,7 @@ }, { "cell_type": "code", - "execution_count": 282, + "execution_count": 442, "metadata": {}, "outputs": [ { @@ -512,7 +514,7 @@ }, { "cell_type": "code", - "execution_count": 283, + "execution_count": 443, "metadata": {}, "outputs": [], "source": [ @@ -524,7 +526,7 @@ }, { "cell_type": "code", - "execution_count": 284, + "execution_count": 444, "metadata": {}, "outputs": [], "source": [ @@ -537,7 +539,7 @@ }, { "cell_type": "code", - "execution_count": 285, + "execution_count": 445, "metadata": {}, "outputs": [ { @@ -710,7 +712,7 @@ }, { "cell_type": "code", - "execution_count": 286, + "execution_count": 446, "metadata": {}, "outputs": [ { @@ -761,7 +763,7 @@ }, { "cell_type": "code", - "execution_count": 287, + "execution_count": 480, "metadata": {}, "outputs": [ { @@ -814,7 +816,7 @@ }, { "cell_type": "code", - "execution_count": 324, + "execution_count": 448, "metadata": {}, "outputs": [], "source": [ @@ -855,7 +857,7 @@ }, { "cell_type": "code", - "execution_count": 289, + "execution_count": 449, "metadata": {}, "outputs": [], "source": [ @@ -865,7 +867,7 @@ }, { "cell_type": "code", - "execution_count": 290, + "execution_count": 450, "metadata": {}, "outputs": [], "source": [ @@ -887,7 +889,7 @@ }, { "cell_type": "code", - "execution_count": 366, + "execution_count": 451, "metadata": {}, "outputs": [], "source": [ @@ -938,7 +940,7 @@ }, { "cell_type": "code", - "execution_count": 292, + "execution_count": 452, "metadata": {}, "outputs": [], "source": [ @@ -954,7 +956,7 @@ }, { "cell_type": "code", - "execution_count": 293, + "execution_count": 453, "metadata": {}, "outputs": [], "source": [ @@ -1002,77 +1004,7 @@ }, { "cell_type": "code", - "execution_count": 294, - "metadata": {}, - "outputs": [], - "source": [ - "# channel_names = ['CD20']\n", - "channel_names = ['CD68', 'CD163', 'HLA-DR', 'CD11c', 'CD14', 'CD16']\n", - "channel_indices = [channel_name_map[name] for name in channel_names]" - ] - }, - { - "cell_type": "code", - "execution_count": 295, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Clinical input variables: (404, 11), Number of missing values: 70\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/var/folders/p9/37n8_h0j3w136cfjm88xkpcr0000gn/T/ipykernel_33714/3982375415.py:12: SettingWithCopyWarning: \n", - "A value is trying to be set on a copy of a slice from a DataFrame.\n", - "Try using .loc[row_indexer,col_indexer] = value instead\n", - "\n", - "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", - " filtered_df[progression_column] = filtered_df[progression_column].astype(int)\n", - "/var/folders/p9/37n8_h0j3w136cfjm88xkpcr0000gn/T/ipykernel_33714/3982375415.py:22: SettingWithCopyWarning: \n", - "A value is trying to be set on a copy of a slice from a DataFrame\n", - "\n", - "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", - " filtered_df.drop(columns=columns_to_drop, inplace=True)\n" - ] - } - ], - "source": [ - "file_path = './LungData/LUAD Clinical Data.xlsx'\n", - "progression_column = 'Progression (No: 0, Yes: 1)'\n", - "\n", - "clinical_variables_input, missing_values_mask = get_clinical_variables_input_and_mask(file_path, progression_column, sheet_name='LUAD_416_Discovery') #, columns=['Sex (Male: 0, Female: 1)'])#, 'Age (<75: 0, ≥75: 1)', 'Predominant histological pattern (Lepidic:1, Papillary: 2, Acinar: 3, Micropapillary: 4, Solid: 5)'])\n", - "print(f'Clinical input variables: {clinical_variables_input.shape}, Number of missing values: {missing_values_mask.sum()}')" - ] - }, - { - "cell_type": "code", - "execution_count": 296, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(404, 12288)" - ] - }, - "execution_count": 296, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "resnet_embeddings = get_embeddings_for_subset_from_file('embeddings_all_markers_discovery.csv', channel_indices)\n", - "resnet_embeddings.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 297, + "execution_count": 454, "metadata": {}, "outputs": [], "source": [ @@ -1100,7 +1032,7 @@ }, { "cell_type": "code", - "execution_count": 298, + "execution_count": 455, "metadata": {}, "outputs": [ { @@ -1109,7 +1041,7 @@ "((404, 12288), (404, 11), (404,))" ] }, - "execution_count": 298, + "execution_count": 455, "metadata": {}, "output_type": "execute_result" } @@ -1128,69 +1060,15 @@ " # Calculate baseline for this subset\n", " baseline = 1 - np.mean(y)\n", " \n", - " return X_res, X_tab, y, baseline\n", - "\n", - "X_res, X_tab, y, baseline = remove_rows_with_nan(resnet_embeddings, clinical_variables_input, progression_labels)\n", - "X_res.shape, X_tab.shape, y.shape\n", - "\n", - "resnet_embeddings.shape, clinical_variables_input.shape, progression_labels.shape" + " return X_res, X_tab, y, baseline" ] }, { "cell_type": "code", - "execution_count": 299, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "((334, 12288), (334, 11), (334,))" - ] - }, - "execution_count": 299, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "X_res.shape, X_tab.shape, y.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 300, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(numpy.ndarray, numpy.ndarray, numpy.ndarray)" - ] - }, - "execution_count": 300, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "X_res_resampled, X_tab_resampled, y_resampled = get_oversampled_dataset(X_res, X_tab, y)\n", - "X_res_resampled.shape, X_tab_resampled.shape, y_resampled.shape\n", - "type(X_res_resampled), type(X_tab_resampled), type(y_resampled)" - ] - }, - { - "cell_type": "code", - "execution_count": 329, + "execution_count": 456, "metadata": {}, "outputs": [], "source": [ - "\"\"\"\n", - "1. ResNet embeddings\n", - "2. ResNet embeddings + Raw tabular\n", - "3. ResNet embeddings + MPL embedded tabular\n", - "\"\"\"\n", - "\n", - "\n", "def train_multimodal_concat(X_res, y, X_tab=None, tabular=None):\n", " \n", " if tabular == \"mlp\":\n", @@ -1236,10 +1114,11 @@ " recall_scores.append(recall)\n", "\n", " # Print metrics\n", - " print(f'Confusion matrix: \\n{conf_matrix}')\n", - " print(f'Fold {k+1} test accuracy: {accuracy}')\n", - " print(f'Fold {k+1} test precision: {precision}')\n", - " print(f'Fold {k+1} test recall: {recall}')\n", + " # print(f'Confusion matrix: \\n{conf_matrix}')\n", + " print(f'Fold {k+1}')\n", + " print(f'Test accuracy: {accuracy}')\n", + " print(f'Test precision: {precision}')\n", + " print(f'Test recall: {recall}\\n')\n", "\n", " # Print average metrics\n", " print(f'Baseline: {1 - np.mean(progression_labels)}')\n", @@ -1251,7 +1130,7 @@ }, { "cell_type": "code", - "execution_count": 322, + "execution_count": 457, "metadata": {}, "outputs": [], "source": [ @@ -1327,48 +1206,178 @@ }, { "cell_type": "code", - "execution_count": 323, + "execution_count": 466, + "metadata": {}, + "outputs": [], + "source": [ + "# channel_names = ['CD20']\n", + "channel_names = ['CD68', 'CD163', 'HLA-DR', 'CD11c', 'CD14', 'CD16']\n", + "channel_indices = [channel_name_map[name] for name in channel_names]" + ] + }, + { + "cell_type": "code", + "execution_count": 476, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(404, 12288)" + ] + }, + "execution_count": 476, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "resnet_embeddings = get_embeddings_for_subset_from_file('embeddings_all_markers_discovery.csv', channel_indices)\n", + "resnet_embeddings.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 467, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Confusion matrix: \n", - "[[56 1]\n", - " [ 3 55]]\n", - "Fold 1 test accuracy: 0.9652173913043478\n", - "Fold 1 test precision: 0.9821428571428571\n", - "Fold 1 test recall: 0.9482758620689655\n", - "Confusion matrix: \n", - "[[58 5]\n", - " [ 0 52]]\n", - "Fold 2 test accuracy: 0.9565217391304348\n", - "Fold 2 test precision: 0.9122807017543859\n", - "Fold 2 test recall: 1.0\n", - "Confusion matrix: \n", - "[[56 1]\n", - " [ 0 57]]\n", - "Fold 3 test accuracy: 0.9912280701754386\n", - "Fold 3 test precision: 0.9827586206896551\n", - "Fold 3 test recall: 1.0\n", - "Confusion matrix: \n", - "[[55 0]\n", - " [ 3 56]]\n", - "Fold 4 test accuracy: 0.9736842105263158\n", - "Fold 4 test precision: 1.0\n", - "Fold 4 test recall: 0.9491525423728814\n", - "Confusion matrix: \n", - "[[53 1]\n", - " [ 0 60]]\n", - "Fold 5 test accuracy: 0.9912280701754386\n", - "Fold 5 test precision: 0.9836065573770492\n", - "Fold 5 test recall: 1.0\n", - "Baseline: 0.8415841584158416\n", - "Average test accuracy: 0.975575896262395\n", - "Standard deviation of accuracies: 0.013884661323256872\n", - "Average test precision: 0.9721577473927894\n", - "Average test recall: 0.9794856808883694\n" + "Clinical input variables: (404, 11), Number of missing values: 70\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/var/folders/p9/37n8_h0j3w136cfjm88xkpcr0000gn/T/ipykernel_33714/3982375415.py:12: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame.\n", + "Try using .loc[row_indexer,col_indexer] = value instead\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " filtered_df[progression_column] = filtered_df[progression_column].astype(int)\n", + "/var/folders/p9/37n8_h0j3w136cfjm88xkpcr0000gn/T/ipykernel_33714/3982375415.py:22: SettingWithCopyWarning: \n", + "A value is trying to be set on a copy of a slice from a DataFrame\n", + "\n", + "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n", + " filtered_df.drop(columns=columns_to_drop, inplace=True)\n" + ] + } + ], + "source": [ + "file_path = './LungData/LUAD Clinical Data.xlsx'\n", + "progression_column = 'Progression (No: 0, Yes: 1)'\n", + "\n", + "clinical_variables_input, missing_values_mask = get_clinical_variables_input_and_mask(file_path, progression_column, sheet_name='LUAD_416_Discovery') #, columns=['Sex (Male: 0, Female: 1)'])#, 'Age (<75: 0, ≥75: 1)', 'Predominant histological pattern (Lepidic:1, Papillary: 2, Acinar: 3, Micropapillary: 4, Solid: 5)'])\n", + "print(f'Clinical input variables: {clinical_variables_input.shape}, Number of missing values: {missing_values_mask.sum()}')" + ] + }, + { + "cell_type": "code", + "execution_count": 475, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((334, 12288), (334, 11), (334,))" + ] + }, + "execution_count": 475, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X_res, X_tab, y, baseline = remove_rows_with_nan(resnet_embeddings, clinical_variables_input, progression_labels)\n", + "X_res.shape, X_tab.shape, y.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 472, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "((572, 12288), (572, 11), (572,))" + ] + }, + "execution_count": 472, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "X_res_resampled, X_tab_resampled, y_resampled = get_oversampled_dataset(X_res, X_tab, y)\n", + "X_res_resampled.shape, X_tab_resampled.shape, y_resampled.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 504, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Baseline for Sex (Male: 0, Female: 1): 84.16\n", + "Baseline for Age (<75: 0, ≥75: 1): 84.16\n", + "Baseline for BMI (<30: 0, ≥30: 1): 84.16\n", + "Baseline for Smoking Status (Smoker: 0, Non-smoker:1): 84.37\n", + "Baseline for Pack Years (1-30: 0, ≥30: 1): 85.37\n", + "Baseline for Stage (I-II: 0, III-IV:1): 84.37\n", + "Baseline for Predominant histological pattern (Lepidic:1, Papillary: 2, Acinar: 3, Micropapillary: 4, Solid: 5): 84.16\n" + ] + } + ], + "source": [ + "clinical_variables_input_column_names = ['Sex (Male: 0, Female: 1)','Age (<75: 0, ≥75: 1)','BMI (<30: 0, ≥30: 1)','Smoking Status (Smoker: 0, Non-smoker:1)','Pack Years (1-30: 0, ≥30: 1)','Stage (I-II: 0, III-IV:1)','Predominant histological pattern (Lepidic:1, Papillary: 2, Acinar: 3, Micropapillary: 4, Solid: 5)']\n", + "\n", + "for clinical_variable in clinical_variables_input_column_names:\n", + " # print(clinical_variable)\n", + " clinical_variables_input, missing_values_mask = get_clinical_variables_input_and_mask(file_path, progression_column, sheet_name='LUAD_416_Discovery', columns=[clinical_variable])\n", + " X_res, X_tab, y, baseline = remove_rows_with_nan(resnet_embeddings, clinical_variables_input, progression_labels)\n", + " baseline = (1 - np.mean(y))*100\n", + " print(f\"Baseline for {clinical_variable}: {baseline:.2f}\")\n", + " # X_res_resampled, X_tab_resampled, y_resampled = get_oversampled_dataset(X_res, X_tab, y)\n", + "\n", + " # train_multimodal_concat(X_res_resampled, y_resampled, X_tab_resampled, tabular=\"raw\")" + ] + }, + { + "cell_type": "code", + "execution_count": 474, + "metadata": {}, + "outputs": [ + { + "ename": "KeyboardInterrupt", + "evalue": "", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", + "Cell \u001b[0;32mIn[474], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39m# TODO fix ValueError: X has 11 features, but MiniBatchSparsePCA is expecting 12288 features as input.\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[39m# TODO Fix train svm pca model function to split train and test\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m train_multimodal_aggregate(X_res_resampled, X_tab_resampled, y_resampled)\n", + "Cell \u001b[0;32mIn[457], line 22\u001b[0m, in \u001b[0;36mtrain_multimodal_aggregate\u001b[0;34m(X_res, X_tab, y)\u001b[0m\n\u001b[1;32m 19\u001b[0m y_train, y_test \u001b[39m=\u001b[39m y[train_index], y[test_index]\n\u001b[1;32m 21\u001b[0m \u001b[39m# Train PCA and SVM on images\u001b[39;00m\n\u001b[0;32m---> 22\u001b[0m pca_model, svm_model \u001b[39m=\u001b[39m train_PCA_SVM_clf(X_res_train, y_train)\n\u001b[1;32m 23\u001b[0m \u001b[39m# Get predicted probabilities from the SVM model\u001b[39;00m\n\u001b[1;32m 24\u001b[0m predicted_probabilities_image_train \u001b[39m=\u001b[39m svm_model\u001b[39m.\u001b[39mpredict_proba(pca_model\u001b[39m.\u001b[39mtransform(X_res_train))\n", + "Cell \u001b[0;32mIn[448], line 6\u001b[0m, in \u001b[0;36mtrain_PCA_SVM_clf\u001b[0;34m(X_train, y_train)\u001b[0m\n\u001b[1;32m 3\u001b[0m pca_model \u001b[39m=\u001b[39m MiniBatchSparsePCA(n_components\u001b[39m=\u001b[39m\u001b[39m9\u001b[39m, batch_size\u001b[39m=\u001b[39m\u001b[39m500\u001b[39m,random_state\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m)\n\u001b[1;32m 5\u001b[0m \u001b[39m# Train PCA model\u001b[39;00m\n\u001b[0;32m----> 6\u001b[0m pca_model\u001b[39m.\u001b[39;49mfit(X_train)\n\u001b[1;32m 8\u001b[0m \u001b[39m# Get PCA embeddings of training data\u001b[39;00m\n\u001b[1;32m 9\u001b[0m X_tr_em \u001b[39m=\u001b[39m pca_model\u001b[39m.\u001b[39mtransform(X_train)\n", + "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/decomposition/_sparse_pca.py:85\u001b[0m, in \u001b[0;36m_BaseSparsePCA.fit\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 82\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 83\u001b[0m n_components \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_components\n\u001b[0;32m---> 85\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_fit(X, n_components, random_state)\n", + "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/decomposition/_sparse_pca.py:527\u001b[0m, in \u001b[0;36mMiniBatchSparsePCA._fit\u001b[0;34m(self, X, n_components, random_state)\u001b[0m\n\u001b[1;32m 524\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"Specialized `fit` for MiniBatchSparsePCA.\"\"\"\u001b[39;00m\n\u001b[1;32m 526\u001b[0m transform_algorithm \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mlasso_\u001b[39m\u001b[39m\"\u001b[39m \u001b[39m+\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmethod\n\u001b[0;32m--> 527\u001b[0m est \u001b[39m=\u001b[39m MiniBatchDictionaryLearning(\n\u001b[1;32m 528\u001b[0m n_components\u001b[39m=\u001b[39;49mn_components,\n\u001b[1;32m 529\u001b[0m alpha\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49malpha,\n\u001b[1;32m 530\u001b[0m n_iter\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mn_iter,\n\u001b[1;32m 531\u001b[0m max_iter\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmax_iter,\n\u001b[1;32m 532\u001b[0m dict_init\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m 533\u001b[0m batch_size\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbatch_size,\n\u001b[1;32m 534\u001b[0m shuffle\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mshuffle,\n\u001b[1;32m 535\u001b[0m n_jobs\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mn_jobs,\n\u001b[1;32m 536\u001b[0m fit_algorithm\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmethod,\n\u001b[1;32m 537\u001b[0m random_state\u001b[39m=\u001b[39;49mrandom_state,\n\u001b[1;32m 538\u001b[0m transform_algorithm\u001b[39m=\u001b[39;49mtransform_algorithm,\n\u001b[1;32m 539\u001b[0m transform_alpha\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49malpha,\n\u001b[1;32m 540\u001b[0m verbose\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mverbose,\n\u001b[1;32m 541\u001b[0m callback\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcallback,\n\u001b[1;32m 542\u001b[0m tol\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtol,\n\u001b[1;32m 543\u001b[0m max_no_improvement\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmax_no_improvement,\n\u001b[1;32m 544\u001b[0m )\u001b[39m.\u001b[39;49mfit(X\u001b[39m.\u001b[39;49mT)\n\u001b[1;32m 546\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcomponents_, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_iter_ \u001b[39m=\u001b[39m est\u001b[39m.\u001b[39mtransform(X\u001b[39m.\u001b[39mT)\u001b[39m.\u001b[39mT, est\u001b[39m.\u001b[39mn_iter_\n\u001b[1;32m 548\u001b[0m components_norm \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39mlinalg\u001b[39m.\u001b[39mnorm(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcomponents_, axis\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)[:, np\u001b[39m.\u001b[39mnewaxis]\n", + "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/decomposition/_dict_learning.py:2377\u001b[0m, in \u001b[0;36mMiniBatchDictionaryLearning.fit\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 2374\u001b[0m batches \u001b[39m=\u001b[39m itertools\u001b[39m.\u001b[39mcycle(batches)\n\u001b[1;32m 2376\u001b[0m \u001b[39mfor\u001b[39;00m i, batch \u001b[39min\u001b[39;00m \u001b[39mzip\u001b[39m(\u001b[39mrange\u001b[39m(n_iter), batches):\n\u001b[0;32m-> 2377\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_minibatch_step(X_train[batch], dictionary, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_random_state, i)\n\u001b[1;32m 2379\u001b[0m trigger_verbose \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mverbose \u001b[39mand\u001b[39;00m i \u001b[39m%\u001b[39m ceil(\u001b[39m100.0\u001b[39m \u001b[39m/\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mverbose) \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m\n\u001b[1;32m 2380\u001b[0m \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mverbose \u001b[39m>\u001b[39m \u001b[39m10\u001b[39m \u001b[39mor\u001b[39;00m trigger_verbose:\n", + "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/decomposition/_dict_learning.py:2166\u001b[0m, in \u001b[0;36mMiniBatchDictionaryLearning._minibatch_step\u001b[0;34m(self, X, dictionary, random_state, step)\u001b[0m\n\u001b[1;32m 2163\u001b[0m batch_size \u001b[39m=\u001b[39m X\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m]\n\u001b[1;32m 2165\u001b[0m \u001b[39m# Compute code for this batch\u001b[39;00m\n\u001b[0;32m-> 2166\u001b[0m code \u001b[39m=\u001b[39m sparse_encode(\n\u001b[1;32m 2167\u001b[0m X,\n\u001b[1;32m 2168\u001b[0m dictionary,\n\u001b[1;32m 2169\u001b[0m algorithm\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_fit_algorithm,\n\u001b[1;32m 2170\u001b[0m alpha\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49malpha,\n\u001b[1;32m 2171\u001b[0m n_jobs\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mn_jobs,\n\u001b[1;32m 2172\u001b[0m check_input\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m,\n\u001b[1;32m 2173\u001b[0m positive\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mpositive_code,\n\u001b[1;32m 2174\u001b[0m max_iter\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtransform_max_iter,\n\u001b[1;32m 2175\u001b[0m verbose\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mverbose,\n\u001b[1;32m 2176\u001b[0m )\n\u001b[1;32m 2178\u001b[0m batch_cost \u001b[39m=\u001b[39m (\n\u001b[1;32m 2179\u001b[0m \u001b[39m0.5\u001b[39m \u001b[39m*\u001b[39m ((X \u001b[39m-\u001b[39m code \u001b[39m@\u001b[39m dictionary) \u001b[39m*\u001b[39m\u001b[39m*\u001b[39m \u001b[39m2\u001b[39m)\u001b[39m.\u001b[39msum()\n\u001b[1;32m 2180\u001b[0m \u001b[39m+\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39malpha \u001b[39m*\u001b[39m np\u001b[39m.\u001b[39msum(np\u001b[39m.\u001b[39mabs(code))\n\u001b[1;32m 2181\u001b[0m ) \u001b[39m/\u001b[39m batch_size\n\u001b[1;32m 2183\u001b[0m \u001b[39m# Update inner stats\u001b[39;00m\n", + "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/decomposition/_dict_learning.py:378\u001b[0m, in \u001b[0;36msparse_encode\u001b[0;34m(X, dictionary, gram, cov, algorithm, n_nonzero_coefs, alpha, copy_cov, init, max_iter, n_jobs, check_input, verbose, positive)\u001b[0m\n\u001b[1;32m 375\u001b[0m regularization \u001b[39m=\u001b[39m \u001b[39m1.0\u001b[39m\n\u001b[1;32m 377\u001b[0m \u001b[39mif\u001b[39;00m effective_n_jobs(n_jobs) \u001b[39m==\u001b[39m \u001b[39m1\u001b[39m \u001b[39mor\u001b[39;00m algorithm \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mthreshold\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[0;32m--> 378\u001b[0m code \u001b[39m=\u001b[39m _sparse_encode(\n\u001b[1;32m 379\u001b[0m X,\n\u001b[1;32m 380\u001b[0m dictionary,\n\u001b[1;32m 381\u001b[0m gram,\n\u001b[1;32m 382\u001b[0m cov\u001b[39m=\u001b[39;49mcov,\n\u001b[1;32m 383\u001b[0m algorithm\u001b[39m=\u001b[39;49malgorithm,\n\u001b[1;32m 384\u001b[0m regularization\u001b[39m=\u001b[39;49mregularization,\n\u001b[1;32m 385\u001b[0m copy_cov\u001b[39m=\u001b[39;49mcopy_cov,\n\u001b[1;32m 386\u001b[0m init\u001b[39m=\u001b[39;49minit,\n\u001b[1;32m 387\u001b[0m max_iter\u001b[39m=\u001b[39;49mmax_iter,\n\u001b[1;32m 388\u001b[0m check_input\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m,\n\u001b[1;32m 389\u001b[0m verbose\u001b[39m=\u001b[39;49mverbose,\n\u001b[1;32m 390\u001b[0m positive\u001b[39m=\u001b[39;49mpositive,\n\u001b[1;32m 391\u001b[0m )\n\u001b[1;32m 392\u001b[0m \u001b[39mreturn\u001b[39;00m code\n\u001b[1;32m 394\u001b[0m \u001b[39m# Enter parallel code block\u001b[39;00m\n", + "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/decomposition/_dict_learning.py:156\u001b[0m, in \u001b[0;36m_sparse_encode\u001b[0;34m(X, dictionary, gram, cov, algorithm, regularization, copy_cov, init, max_iter, check_input, verbose, positive)\u001b[0m\n\u001b[1;32m 145\u001b[0m \u001b[39m# Not passing in verbose=max(0, verbose-1) because Lars.fit already\u001b[39;00m\n\u001b[1;32m 146\u001b[0m \u001b[39m# corrects the verbosity level.\u001b[39;00m\n\u001b[1;32m 147\u001b[0m lasso_lars \u001b[39m=\u001b[39m LassoLars(\n\u001b[1;32m 148\u001b[0m alpha\u001b[39m=\u001b[39malpha,\n\u001b[1;32m 149\u001b[0m fit_intercept\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 154\u001b[0m max_iter\u001b[39m=\u001b[39mmax_iter,\n\u001b[1;32m 155\u001b[0m )\n\u001b[0;32m--> 156\u001b[0m lasso_lars\u001b[39m.\u001b[39;49mfit(dictionary\u001b[39m.\u001b[39;49mT, X\u001b[39m.\u001b[39;49mT, Xy\u001b[39m=\u001b[39;49mcov)\n\u001b[1;32m 157\u001b[0m new_code \u001b[39m=\u001b[39m lasso_lars\u001b[39m.\u001b[39mcoef_\n\u001b[1;32m 158\u001b[0m \u001b[39mfinally\u001b[39;00m:\n", + "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/linear_model/_least_angle.py:1144\u001b[0m, in \u001b[0;36mLars.fit\u001b[0;34m(self, X, y, Xy)\u001b[0m\n\u001b[1;32m 1141\u001b[0m noise \u001b[39m=\u001b[39m rng\u001b[39m.\u001b[39muniform(high\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mjitter, size\u001b[39m=\u001b[39m\u001b[39mlen\u001b[39m(y))\n\u001b[1;32m 1142\u001b[0m y \u001b[39m=\u001b[39m y \u001b[39m+\u001b[39m noise\n\u001b[0;32m-> 1144\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_fit(\n\u001b[1;32m 1145\u001b[0m X,\n\u001b[1;32m 1146\u001b[0m y,\n\u001b[1;32m 1147\u001b[0m max_iter\u001b[39m=\u001b[39;49mmax_iter,\n\u001b[1;32m 1148\u001b[0m alpha\u001b[39m=\u001b[39;49malpha,\n\u001b[1;32m 1149\u001b[0m fit_path\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mfit_path,\n\u001b[1;32m 1150\u001b[0m normalize\u001b[39m=\u001b[39;49m_normalize,\n\u001b[1;32m 1151\u001b[0m Xy\u001b[39m=\u001b[39;49mXy,\n\u001b[1;32m 1152\u001b[0m )\n\u001b[1;32m 1154\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\n", + "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/linear_model/_least_angle.py:1077\u001b[0m, in \u001b[0;36mLars._fit\u001b[0;34m(self, X, y, max_iter, alpha, fit_path, normalize, Xy)\u001b[0m\n\u001b[1;32m 1075\u001b[0m \u001b[39mfor\u001b[39;00m k \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(n_targets):\n\u001b[1;32m 1076\u001b[0m this_Xy \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m \u001b[39mif\u001b[39;00m Xy \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m Xy[:, k]\n\u001b[0;32m-> 1077\u001b[0m alphas, _, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcoef_[k], n_iter_ \u001b[39m=\u001b[39m lars_path(\n\u001b[1;32m 1078\u001b[0m X,\n\u001b[1;32m 1079\u001b[0m y[:, k],\n\u001b[1;32m 1080\u001b[0m Gram\u001b[39m=\u001b[39;49mGram,\n\u001b[1;32m 1081\u001b[0m Xy\u001b[39m=\u001b[39;49mthis_Xy,\n\u001b[1;32m 1082\u001b[0m copy_X\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcopy_X,\n\u001b[1;32m 1083\u001b[0m copy_Gram\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m 1084\u001b[0m alpha_min\u001b[39m=\u001b[39;49malpha,\n\u001b[1;32m 1085\u001b[0m method\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmethod,\n\u001b[1;32m 1086\u001b[0m verbose\u001b[39m=\u001b[39;49m\u001b[39mmax\u001b[39;49m(\u001b[39m0\u001b[39;49m, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mverbose \u001b[39m-\u001b[39;49m \u001b[39m1\u001b[39;49m),\n\u001b[1;32m 1087\u001b[0m max_iter\u001b[39m=\u001b[39;49mmax_iter,\n\u001b[1;32m 1088\u001b[0m eps\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49meps,\n\u001b[1;32m 1089\u001b[0m return_path\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m,\n\u001b[1;32m 1090\u001b[0m return_n_iter\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m 1091\u001b[0m positive\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mpositive,\n\u001b[1;32m 1092\u001b[0m )\n\u001b[1;32m 1093\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39malphas_\u001b[39m.\u001b[39mappend(alphas)\n\u001b[1;32m 1094\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_iter_\u001b[39m.\u001b[39mappend(n_iter_)\n", + "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/linear_model/_least_angle.py:170\u001b[0m, in \u001b[0;36mlars_path\u001b[0;34m(X, y, Xy, Gram, max_iter, alpha_min, method, copy_X, eps, copy_Gram, verbose, return_path, return_n_iter, positive)\u001b[0m\n\u001b[1;32m 165\u001b[0m \u001b[39mif\u001b[39;00m X \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m Gram \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 166\u001b[0m \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m 167\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mX cannot be None if Gram is not None\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 168\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mUse lars_path_gram to avoid passing X and y.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m 169\u001b[0m )\n\u001b[0;32m--> 170\u001b[0m \u001b[39mreturn\u001b[39;00m _lars_path_solver(\n\u001b[1;32m 171\u001b[0m X\u001b[39m=\u001b[39;49mX,\n\u001b[1;32m 172\u001b[0m y\u001b[39m=\u001b[39;49my,\n\u001b[1;32m 173\u001b[0m Xy\u001b[39m=\u001b[39;49mXy,\n\u001b[1;32m 174\u001b[0m Gram\u001b[39m=\u001b[39;49mGram,\n\u001b[1;32m 175\u001b[0m n_samples\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m 176\u001b[0m max_iter\u001b[39m=\u001b[39;49mmax_iter,\n\u001b[1;32m 177\u001b[0m alpha_min\u001b[39m=\u001b[39;49malpha_min,\n\u001b[1;32m 178\u001b[0m method\u001b[39m=\u001b[39;49mmethod,\n\u001b[1;32m 179\u001b[0m copy_X\u001b[39m=\u001b[39;49mcopy_X,\n\u001b[1;32m 180\u001b[0m eps\u001b[39m=\u001b[39;49meps,\n\u001b[1;32m 181\u001b[0m copy_Gram\u001b[39m=\u001b[39;49mcopy_Gram,\n\u001b[1;32m 182\u001b[0m verbose\u001b[39m=\u001b[39;49mverbose,\n\u001b[1;32m 183\u001b[0m return_path\u001b[39m=\u001b[39;49mreturn_path,\n\u001b[1;32m 184\u001b[0m return_n_iter\u001b[39m=\u001b[39;49mreturn_n_iter,\n\u001b[1;32m 185\u001b[0m positive\u001b[39m=\u001b[39;49mpositive,\n\u001b[1;32m 186\u001b[0m )\n", + "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/linear_model/_least_angle.py:729\u001b[0m, in \u001b[0;36m_lars_path_solver\u001b[0;34m(X, y, Xy, Gram, n_samples, max_iter, alpha_min, method, copy_X, eps, copy_Gram, verbose, return_path, return_n_iter, positive)\u001b[0m\n\u001b[1;32m 725\u001b[0m corr_eq_dir \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39mdot(Gram[:n_active, n_active:]\u001b[39m.\u001b[39mT, least_squares)\n\u001b[1;32m 727\u001b[0m \u001b[39m# Explicit rounding can be necessary to avoid `np.argmax(Cov)` yielding\u001b[39;00m\n\u001b[1;32m 728\u001b[0m \u001b[39m# unstable results because of rounding errors.\u001b[39;00m\n\u001b[0;32m--> 729\u001b[0m np\u001b[39m.\u001b[39;49maround(corr_eq_dir, decimals\u001b[39m=\u001b[39;49mcov_precision, out\u001b[39m=\u001b[39;49mcorr_eq_dir)\n\u001b[1;32m 731\u001b[0m g1 \u001b[39m=\u001b[39m arrayfuncs\u001b[39m.\u001b[39mmin_pos((C \u001b[39m-\u001b[39m Cov) \u001b[39m/\u001b[39m (AA \u001b[39m-\u001b[39m corr_eq_dir \u001b[39m+\u001b[39m tiny32))\n\u001b[1;32m 732\u001b[0m \u001b[39mif\u001b[39;00m positive:\n", + "File \u001b[0;32m<__array_function__ internals>:200\u001b[0m, in \u001b[0;36maround\u001b[0;34m(*args, **kwargs)\u001b[0m\n", + "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/numpy/core/fromnumeric.py:3337\u001b[0m, in \u001b[0;36maround\u001b[0;34m(a, decimals, out)\u001b[0m\n\u001b[1;32m 3245\u001b[0m \u001b[39m@array_function_dispatch\u001b[39m(_around_dispatcher)\n\u001b[1;32m 3246\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39maround\u001b[39m(a, decimals\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m, out\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m):\n\u001b[1;32m 3247\u001b[0m \u001b[39m \u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m 3248\u001b[0m \u001b[39m Evenly round to the given number of decimals.\u001b[39;00m\n\u001b[1;32m 3249\u001b[0m \n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 3335\u001b[0m \n\u001b[1;32m 3336\u001b[0m \u001b[39m \"\"\"\u001b[39;00m\n\u001b[0;32m-> 3337\u001b[0m \u001b[39mreturn\u001b[39;00m _wrapfunc(a, \u001b[39m'\u001b[39;49m\u001b[39mround\u001b[39;49m\u001b[39m'\u001b[39;49m, decimals\u001b[39m=\u001b[39;49mdecimals, out\u001b[39m=\u001b[39;49mout)\n", + "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/numpy/core/fromnumeric.py:57\u001b[0m, in \u001b[0;36m_wrapfunc\u001b[0;34m(obj, method, *args, **kwds)\u001b[0m\n\u001b[1;32m 54\u001b[0m \u001b[39mreturn\u001b[39;00m _wrapit(obj, method, \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwds)\n\u001b[1;32m 56\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m---> 57\u001b[0m \u001b[39mreturn\u001b[39;00m bound(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwds)\n\u001b[1;32m 58\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mTypeError\u001b[39;00m:\n\u001b[1;32m 59\u001b[0m \u001b[39m# A TypeError occurs if the object does have such a method in its\u001b[39;00m\n\u001b[1;32m 60\u001b[0m \u001b[39m# class, but its signature is not identical to that of NumPy's. This\u001b[39;00m\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 64\u001b[0m \u001b[39m# Call _wrapit from within the except clause to ensure a potential\u001b[39;00m\n\u001b[1;32m 65\u001b[0m \u001b[39m# exception has a traceback chain.\u001b[39;00m\n\u001b[1;32m 66\u001b[0m \u001b[39mreturn\u001b[39;00m _wrapit(obj, method, \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwds)\n", + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], @@ -1380,7 +1389,7 @@ }, { "cell_type": "code", - "execution_count": 330, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1469,7 +1478,7 @@ }, { "cell_type": "code", - "execution_count": 367, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -1536,6 +1545,168 @@ "print(f'Average test precision: {np.mean(precision_scores)}')\n", "print(f'Average test recall: {np.mean(recall_scores)}')" ] + }, + { + "cell_type": "code", + "execution_count": 499, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Header: Sex (Male: 0, Female: 1)\n", + "Accuracies: ['94.85', '93.38', '94.12', '94.85', '93.38']\n", + "Average Accuracy: 94.12\n", + "Standard Deviation of Accuracies: 0.66\n", + "Average Precision: 90.29\n", + "Average Recall: 98.84\n", + "\n", + "Latex string\n", + "& 94.85 & 93.38 & 94.12 & 94.85 & 93.38 & 94.12 & 0.66 & 90.29 & 98.84\n", + "\n", + "Header: Age (<75: 0, ≥75: 1)\n", + "Accuracies: ['94.85', '93.38', '94.12', '95.59', '93.38']\n", + "Average Accuracy: 94.26\n", + "Standard Deviation of Accuracies: 0.86\n", + "Average Precision: 90.53\n", + "Average Recall: 98.84\n", + "\n", + "Latex string\n", + "& 94.85 & 93.38 & 94.12 & 95.59 & 93.38 & 94.26 & 0.86 & 90.53 & 98.84\n", + "\n", + "Header: BMI (<30: 0, ≥30: 1)\n", + "Accuracies: ['94.85', '93.38', '94.12', '94.85', '93.38']\n", + "Average Accuracy: 94.12\n", + "Standard Deviation of Accuracies: 0.66\n", + "Average Precision: 90.29\n", + "Average Recall: 98.84\n", + "\n", + "Latex string\n", + "& 94.85 & 93.38 & 94.12 & 94.85 & 93.38 & 94.12 & 0.66 & 90.29 & 98.84\n", + "\n", + "Header: Smoking Status (Smoker: 0, Non-smoker:1)\n", + "Accuracies: ['91.91', '96.32', '95.59', '97.06', '97.06']\n", + "Average Accuracy: 95.59\n", + "Standard Deviation of Accuracies: 1.92\n", + "Average Precision: 92.54\n", + "Average Recall: 99.38\n", + "\n", + "Latex string\n", + "& 91.91 & 96.32 & 95.59 & 97.06 & 97.06 & 95.59 & 1.92 & 92.54 & 99.38\n", + "\n", + "Header: Pack Years (1-30: 0, ≥30: 1)\n", + "Accuracies: ['95.65', '93.91', '94.74', '93.86', '98.25']\n", + "Average Accuracy: 95.28\n", + "Standard Deviation of Accuracies: 1.62\n", + "Average Precision: 93.01\n", + "Average Recall: 97.95\n", + "\n", + "Latex string\n", + "& 95.65 & 93.91 & 94.74 & 93.86 & 98.25 & 95.28 & 1.62 & 93.01 & 97.95\n", + "\n", + "Header: Stage (I-II: 0, III-IV:1)\n", + "Accuracies: ['93.38', '96.32', '94.12', '97.79', '96.32']\n", + "Average Accuracy: 95.59\n", + "Standard Deviation of Accuracies: 1.61\n", + "Average Precision: 92.63\n", + "Average Recall: 99.09\n", + "\n", + "Latex string\n", + "& 93.38 & 96.32 & 94.12 & 97.79 & 96.32 & 95.59 & 1.61 & 92.63 & 99.09\n", + "\n", + "Header: Predominant histological pattern (Lepidic:1, Papillary: 2, Acinar: 3, Micropapillary: 4, Solid: 5)\n", + "Accuracies: ['97.79', '94.85', '94.12', '97.06', '94.85']\n", + "Average Accuracy: 95.74\n", + "Standard Deviation of Accuracies: 1.43\n", + "Average Precision: 93.06\n", + "Average Recall: 98.84\n", + "\n", + "Latex string\n", + "& 97.79 & 94.85 & 94.12 & 97.06 & 94.85 & 95.74 & 1.43 & 93.06 & 98.84\n", + "\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "\n", + "# Define the headers\n", + "headers = [\n", + " 'Sex (Male: 0, Female: 1)',\n", + " 'Age (<75: 0, ≥75: 1)',\n", + " 'BMI (<30: 0, ≥30: 1)',\n", + " 'Smoking Status (Smoker: 0, Non-smoker:1)',\n", + " 'Pack Years (1-30: 0, ≥30: 1)',\n", + " 'Stage (I-II: 0, III-IV:1)',\n", + " 'Predominant histological pattern (Lepidic:1, Papillary: 2, Acinar: 3, Micropapillary: 4, Solid: 5)'\n", + "]\n", + "\n", + "# Initialize dictionaries to store results\n", + "results = {header: {'accuracies': [], 'precisions': [], 'recalls': []} for header in headers}\n", + "\n", + "# Read the file\n", + "with open('./results_each_clin_var.txt', 'r') as file:\n", + " lines = file.readlines()\n", + "\n", + "# Parse the data\n", + "current_header = None\n", + "for line in lines:\n", + " line = line.strip()\n", + " if line in headers:\n", + " current_header = line\n", + " elif 'Test accuracy' in line:\n", + " accuracy = float(line.split(': ')[1]) * 100\n", + " results[current_header]['accuracies'].append(accuracy)\n", + " elif 'Test precision' in line:\n", + " precision = float(line.split(': ')[1]) * 100\n", + " results[current_header]['precisions'].append(precision)\n", + " elif 'Test recall' in line:\n", + " recall = float(line.split(': ')[1]) * 100\n", + " results[current_header]['recalls'].append(recall)\n", + "\n", + "# Calculate statistics\n", + "for header in headers:\n", + " accuracies = results[header]['accuracies']\n", + " precisions = results[header]['precisions']\n", + " recalls = results[header]['recalls']\n", + " \n", + " avg_accuracy = np.mean(accuracies)\n", + " std_accuracy = np.std(accuracies)\n", + " avg_precision = np.mean(precisions)\n", + " avg_recall = np.mean(recalls)\n", + " \n", + " results[header]['avg_accuracy'] = avg_accuracy\n", + " results[header]['std_accuracy'] = std_accuracy\n", + " results[header]['avg_precision'] = avg_precision\n", + " results[header]['avg_recall'] = avg_recall\n", + "\n", + " formatted_accuracies = [f\"{acc:.2f}\" for acc in accuracies]\n", + "\n", + " print(f\"Header: {header}\")\n", + " print(f\"Accuracies: {formatted_accuracies}\")\n", + " print(f\"Average Accuracy: {avg_accuracy:.2f}\")\n", + " print(f\"Standard Deviation of Accuracies: {std_accuracy:.2f}\")\n", + " print(f\"Average Precision: {avg_precision:.2f}\")\n", + " print(f\"Average Recall: {avg_recall:.2f}\")\n", + " print()\n", + "\n", + " latex_string = \"&\"\n", + " \n", + " for acc in accuracies:\n", + " latex_string += f\" {acc:.2f} &\"\n", + "\n", + " latex_string += f\" {avg_accuracy:.2f} &\"\n", + " latex_string += f\" {std_accuracy:.2f} &\"\n", + " latex_string += f\" {avg_precision:.2f} &\"\n", + " latex_string += f\" {avg_recall:.2f}\"\n", + "\n", + " print('Latex string')\n", + " print(latex_string)\n", + " print()\n", + "\n", + "# Note: Results are in percent as required.\n" + ] } ], "metadata": { diff --git a/results_each_clin_var.txt b/results_each_clin_var.txt new file mode 100644 index 0000000000000000000000000000000000000000..1e985bbecb504e7502cc571d818ec3ec97537929 --- /dev/null +++ b/results_each_clin_var.txt @@ -0,0 +1,223 @@ +Sex (Male: 0, Female: 1) +Fold 1 +Test accuracy: 0.9485294117647058 +Test precision: 0.9066666666666666 +Test recall: 1.0 + +Fold 2 +Test accuracy: 0.9338235294117647 +Test precision: 0.8873239436619719 +Test recall: 0.984375 + +Fold 3 +Test accuracy: 0.9411764705882353 +Test precision: 0.8904109589041096 +Test recall: 1.0 + +Fold 4 +Test accuracy: 0.9485294117647058 +Test precision: 0.9113924050632911 +Test recall: 1.0 + +Fold 5 +Test accuracy: 0.9338235294117647 +Test precision: 0.918918918918919 +Test recall: 0.9577464788732394 + +Baseline: 0.8415841584158416 +Average test accuracy: 0.9411764705882352 +Standard deviation of accuracies: 0.006576670522058182 +Average test precision: 0.9029425786429917 +Average test recall: 0.9884242957746479 + +Age (<75: 0, ≥75: 1) +Fold 1 +Test accuracy: 0.9485294117647058 +Test precision: 0.9066666666666666 +Test recall: 1.0 + +Fold 2 +Test accuracy: 0.9338235294117647 +Test precision: 0.8873239436619719 +Test recall: 0.984375 + +Fold 3 +Test accuracy: 0.9411764705882353 +Test precision: 0.8904109589041096 +Test recall: 1.0 + +Fold 4 +Test accuracy: 0.9558823529411765 +Test precision: 0.9230769230769231 +Test recall: 1.0 + +Fold 5 +Test accuracy: 0.9338235294117647 +Test precision: 0.918918918918919 +Test recall: 0.9577464788732394 + +Baseline: 0.8415841584158416 +Average test accuracy: 0.9426470588235294 +Standard deviation of accuracies: 0.008574929257125446 +Average test precision: 0.9052794822457182 +Average test recall: 0.9884242957746479 + +BMI (<30: 0, ≥30: 1) +Fold 1 +Test accuracy: 0.9485294117647058 +Test precision: 0.9066666666666666 +Test recall: 1.0 + +Fold 2 +Test accuracy: 0.9338235294117647 +Test precision: 0.8873239436619719 +Test recall: 0.984375 + +Fold 3 +Test accuracy: 0.9411764705882353 +Test precision: 0.8904109589041096 +Test recall: 1.0 + +Fold 4 +Test accuracy: 0.9485294117647058 +Test precision: 0.9113924050632911 +Test recall: 1.0 + +Fold 5 +Test accuracy: 0.9338235294117647 +Test precision: 0.918918918918919 +Test recall: 0.9577464788732394 + +Baseline: 0.8415841584158416 +Average test accuracy: 0.9411764705882352 +Standard deviation of accuracies: 0.006576670522058182 +Average test precision: 0.9029425786429917 +Average test recall: 0.9884242957746479 + +Smoking Status (Smoker: 0, Non-smoker:1) +Fold 1 +Test accuracy: 0.9191176470588235 +Test precision: 0.8658536585365854 +Test recall: 1.0 + +Fold 2 +Test accuracy: 0.9632352941176471 +Test precision: 0.9393939393939394 +Test recall: 0.9841269841269841 + +Fold 3 +Test accuracy: 0.9558823529411765 +Test precision: 0.9178082191780822 +Test recall: 1.0 + +Fold 4 +Test accuracy: 0.9705882352941176 +Test precision: 0.9473684210526315 +Test recall: 1.0 + +Fold 5 +Test accuracy: 0.9705882352941176 +Test precision: 0.9565217391304348 +Test recall: 0.9850746268656716 + +Baseline: 0.8415841584158416 +Average test accuracy: 0.9558823529411764 +Standard deviation of accuracies: 0.01917412472118428 +Average test precision: 0.9253891954583346 +Average test recall: 0.9938403221985311 + +Pack Years (1-30: 0, ≥30: 1) +Fold 1 +Test accuracy: 0.9565217391304348 +Test precision: 0.9649122807017544 +Test recall: 0.9482758620689655 + +Fold 2 +Test accuracy: 0.9391304347826087 +Test precision: 0.8793103448275862 +Test recall: 1.0 + +Fold 3 +Test accuracy: 0.9473684210526315 +Test precision: 0.9047619047619048 +Test recall: 1.0 + +Fold 4 +Test accuracy: 0.9385964912280702 +Test precision: 0.9333333333333333 +Test recall: 0.9491525423728814 + +Fold 5 +Test accuracy: 0.9824561403508771 +Test precision: 0.9682539682539683 +Test recall: 1.0 + +Baseline: 0.8415841584158416 +Average test accuracy: 0.9528146453089246 +Standard deviation of accuracies: 0.0161952811322858 +Average test precision: 0.9301143663757093 +Average test recall: 0.9794856808883694 + +Stage (I-II: 0, III-IV:1) +Fold 1 +Test accuracy: 0.9338235294117647 +Test precision: 0.8846153846153846 +Test recall: 1.0 + +Fold 2 +Test accuracy: 0.9632352941176471 +Test precision: 0.9402985074626866 +Test recall: 0.984375 + +Fold 3 +Test accuracy: 0.9411764705882353 +Test precision: 0.9014084507042254 +Test recall: 0.9846153846153847 + +Fold 4 +Test accuracy: 0.9779411764705882 +Test precision: 0.9605263157894737 +Test recall: 1.0 + +Fold 5 +Test accuracy: 0.9632352941176471 +Test precision: 0.9444444444444444 +Test recall: 0.9855072463768116 + +Baseline: 0.8415841584158416 +Average test accuracy: 0.9558823529411764 +Standard deviation of accuracies: 0.016109486985446054 +Average test precision: 0.9262586206032429 +Average test recall: 0.9908995261984392 + +Predominant histological pattern (Lepidic:1, Papillary: 2, Acinar: 3, Micropapillary: 4, Solid: 5) +Fold 1 +Test accuracy: 0.9779411764705882 +Test precision: 0.9577464788732394 +Test recall: 1.0 + +Fold 2 +Test accuracy: 0.9485294117647058 +Test precision: 0.9130434782608695 +Test recall: 0.984375 + +Fold 3 +Test accuracy: 0.9411764705882353 +Test precision: 0.8904109589041096 +Test recall: 1.0 + +Fold 4 +Test accuracy: 0.9705882352941176 +Test precision: 0.9473684210526315 +Test recall: 1.0 + +Fold 5 +Test accuracy: 0.9485294117647058 +Test precision: 0.9444444444444444 +Test recall: 0.9577464788732394 + +Baseline: 0.8415841584158416 +Average test accuracy: 0.9573529411764705 +Standard deviation of accuracies: 0.01425788193357744 +Average test precision: 0.9306027563070588 +Average test recall: 0.9884242957746479