diff --git a/code.ipynb b/code.ipynb
index 71797ae083d89865d53e162031176f929c5629a7..d50b6f2ed76bf985c5d7e86329287591ef8bfa4d 100644
--- a/code.ipynb
+++ b/code.ipynb
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 314,
+   "execution_count": 485,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -30,12 +30,14 @@
     "from sklearn.gaussian_process.kernels import RBF\n",
     "from sklearn.linear_model import LogisticRegression\n",
     "from sklearn.pipeline import make_pipeline\n",
-    "from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score"
+    "from sklearn.metrics import confusion_matrix, accuracy_score, precision_score, recall_score\n",
+    "\n",
+    "pd.options.mode.copy_on_write = True"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 279,
+   "execution_count": 439,
    "metadata": {},
    "outputs": [
     {
@@ -360,7 +362,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 280,
+   "execution_count": 440,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -387,7 +389,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 281,
+   "execution_count": 441,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -482,7 +484,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 282,
+   "execution_count": 442,
    "metadata": {},
    "outputs": [
     {
@@ -512,7 +514,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 283,
+   "execution_count": 443,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -524,7 +526,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 284,
+   "execution_count": 444,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -537,7 +539,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 285,
+   "execution_count": 445,
    "metadata": {},
    "outputs": [
     {
@@ -710,7 +712,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 286,
+   "execution_count": 446,
    "metadata": {},
    "outputs": [
     {
@@ -761,7 +763,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 287,
+   "execution_count": 480,
    "metadata": {},
    "outputs": [
     {
@@ -814,7 +816,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 324,
+   "execution_count": 448,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -855,7 +857,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 289,
+   "execution_count": 449,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -865,7 +867,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 290,
+   "execution_count": 450,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -887,7 +889,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 366,
+   "execution_count": 451,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -938,7 +940,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 292,
+   "execution_count": 452,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -954,7 +956,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 293,
+   "execution_count": 453,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1002,77 +1004,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 294,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# channel_names = ['CD20']\n",
-    "channel_names = ['CD68', 'CD163', 'HLA-DR', 'CD11c', 'CD14', 'CD16']\n",
-    "channel_indices = [channel_name_map[name] for name in channel_names]"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 295,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Clinical input variables: (404, 11), Number of missing values: 70\n"
-     ]
-    },
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/var/folders/p9/37n8_h0j3w136cfjm88xkpcr0000gn/T/ipykernel_33714/3982375415.py:12: SettingWithCopyWarning: \n",
-      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
-      "Try using .loc[row_indexer,col_indexer] = value instead\n",
-      "\n",
-      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
-      "  filtered_df[progression_column] = filtered_df[progression_column].astype(int)\n",
-      "/var/folders/p9/37n8_h0j3w136cfjm88xkpcr0000gn/T/ipykernel_33714/3982375415.py:22: SettingWithCopyWarning: \n",
-      "A value is trying to be set on a copy of a slice from a DataFrame\n",
-      "\n",
-      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
-      "  filtered_df.drop(columns=columns_to_drop, inplace=True)\n"
-     ]
-    }
-   ],
-   "source": [
-    "file_path = './LungData/LUAD Clinical Data.xlsx'\n",
-    "progression_column = 'Progression (No: 0, Yes: 1)'\n",
-    "\n",
-    "clinical_variables_input, missing_values_mask = get_clinical_variables_input_and_mask(file_path, progression_column, sheet_name='LUAD_416_Discovery') #, columns=['Sex (Male: 0, Female: 1)'])#, 'Age (<75: 0, ≥75: 1)', 'Predominant histological pattern (Lepidic:1, Papillary: 2, Acinar: 3, Micropapillary: 4, Solid: 5)'])\n",
-    "print(f'Clinical input variables: {clinical_variables_input.shape}, Number of missing values: {missing_values_mask.sum()}')"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 296,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "(404, 12288)"
-      ]
-     },
-     "execution_count": 296,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "resnet_embeddings = get_embeddings_for_subset_from_file('embeddings_all_markers_discovery.csv', channel_indices)\n",
-    "resnet_embeddings.shape"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 297,
+   "execution_count": 454,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1100,7 +1032,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 298,
+   "execution_count": 455,
    "metadata": {},
    "outputs": [
     {
@@ -1109,7 +1041,7 @@
        "((404, 12288), (404, 11), (404,))"
       ]
      },
-     "execution_count": 298,
+     "execution_count": 455,
      "metadata": {},
      "output_type": "execute_result"
     }
@@ -1128,69 +1060,15 @@
     "    # Calculate baseline for this subset\n",
     "    baseline = 1 - np.mean(y)\n",
     "    \n",
-    "    return X_res, X_tab, y, baseline\n",
-    "\n",
-    "X_res, X_tab, y, baseline = remove_rows_with_nan(resnet_embeddings, clinical_variables_input, progression_labels)\n",
-    "X_res.shape, X_tab.shape, y.shape\n",
-    "\n",
-    "resnet_embeddings.shape, clinical_variables_input.shape, progression_labels.shape"
+    "    return X_res, X_tab, y, baseline"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 299,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "((334, 12288), (334, 11), (334,))"
-      ]
-     },
-     "execution_count": 299,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "X_res.shape, X_tab.shape, y.shape"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 300,
-   "metadata": {},
-   "outputs": [
-    {
-     "data": {
-      "text/plain": [
-       "(numpy.ndarray, numpy.ndarray, numpy.ndarray)"
-      ]
-     },
-     "execution_count": 300,
-     "metadata": {},
-     "output_type": "execute_result"
-    }
-   ],
-   "source": [
-    "X_res_resampled, X_tab_resampled, y_resampled = get_oversampled_dataset(X_res, X_tab, y)\n",
-    "X_res_resampled.shape, X_tab_resampled.shape, y_resampled.shape\n",
-    "type(X_res_resampled), type(X_tab_resampled), type(y_resampled)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 329,
+   "execution_count": 456,
    "metadata": {},
    "outputs": [],
    "source": [
-    "\"\"\"\n",
-    "1. ResNet embeddings\n",
-    "2. ResNet embeddings + Raw tabular\n",
-    "3. ResNet embeddings + MPL embedded tabular\n",
-    "\"\"\"\n",
-    "\n",
-    "\n",
     "def train_multimodal_concat(X_res, y, X_tab=None, tabular=None):\n",
     "    \n",
     "    if tabular == \"mlp\":\n",
@@ -1236,10 +1114,11 @@
     "        recall_scores.append(recall)\n",
     "\n",
     "        # Print metrics\n",
-    "        print(f'Confusion matrix: \\n{conf_matrix}')\n",
-    "        print(f'Fold {k+1} test accuracy: {accuracy}')\n",
-    "        print(f'Fold {k+1} test precision: {precision}')\n",
-    "        print(f'Fold {k+1} test recall: {recall}')\n",
+    "        # print(f'Confusion matrix: \\n{conf_matrix}')\n",
+    "        print(f'Fold {k+1}')\n",
+    "        print(f'Test accuracy: {accuracy}')\n",
+    "        print(f'Test precision: {precision}')\n",
+    "        print(f'Test recall: {recall}\\n')\n",
     "\n",
     "    # Print average metrics\n",
     "    print(f'Baseline: {1 - np.mean(progression_labels)}')\n",
@@ -1251,7 +1130,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 322,
+   "execution_count": 457,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -1327,48 +1206,178 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 323,
+   "execution_count": 466,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# channel_names = ['CD20']\n",
+    "channel_names = ['CD68', 'CD163', 'HLA-DR', 'CD11c', 'CD14', 'CD16']\n",
+    "channel_indices = [channel_name_map[name] for name in channel_names]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 476,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "(404, 12288)"
+      ]
+     },
+     "execution_count": 476,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "resnet_embeddings = get_embeddings_for_subset_from_file('embeddings_all_markers_discovery.csv', channel_indices)\n",
+    "resnet_embeddings.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 467,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "Confusion matrix: \n",
-      "[[56  1]\n",
-      " [ 3 55]]\n",
-      "Fold 1 test accuracy: 0.9652173913043478\n",
-      "Fold 1 test precision: 0.9821428571428571\n",
-      "Fold 1 test recall: 0.9482758620689655\n",
-      "Confusion matrix: \n",
-      "[[58  5]\n",
-      " [ 0 52]]\n",
-      "Fold 2 test accuracy: 0.9565217391304348\n",
-      "Fold 2 test precision: 0.9122807017543859\n",
-      "Fold 2 test recall: 1.0\n",
-      "Confusion matrix: \n",
-      "[[56  1]\n",
-      " [ 0 57]]\n",
-      "Fold 3 test accuracy: 0.9912280701754386\n",
-      "Fold 3 test precision: 0.9827586206896551\n",
-      "Fold 3 test recall: 1.0\n",
-      "Confusion matrix: \n",
-      "[[55  0]\n",
-      " [ 3 56]]\n",
-      "Fold 4 test accuracy: 0.9736842105263158\n",
-      "Fold 4 test precision: 1.0\n",
-      "Fold 4 test recall: 0.9491525423728814\n",
-      "Confusion matrix: \n",
-      "[[53  1]\n",
-      " [ 0 60]]\n",
-      "Fold 5 test accuracy: 0.9912280701754386\n",
-      "Fold 5 test precision: 0.9836065573770492\n",
-      "Fold 5 test recall: 1.0\n",
-      "Baseline: 0.8415841584158416\n",
-      "Average test accuracy: 0.975575896262395\n",
-      "Standard deviation of accuracies: 0.013884661323256872\n",
-      "Average test precision: 0.9721577473927894\n",
-      "Average test recall: 0.9794856808883694\n"
+      "Clinical input variables: (404, 11), Number of missing values: 70\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "/var/folders/p9/37n8_h0j3w136cfjm88xkpcr0000gn/T/ipykernel_33714/3982375415.py:12: SettingWithCopyWarning: \n",
+      "A value is trying to be set on a copy of a slice from a DataFrame.\n",
+      "Try using .loc[row_indexer,col_indexer] = value instead\n",
+      "\n",
+      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
+      "  filtered_df[progression_column] = filtered_df[progression_column].astype(int)\n",
+      "/var/folders/p9/37n8_h0j3w136cfjm88xkpcr0000gn/T/ipykernel_33714/3982375415.py:22: SettingWithCopyWarning: \n",
+      "A value is trying to be set on a copy of a slice from a DataFrame\n",
+      "\n",
+      "See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
+      "  filtered_df.drop(columns=columns_to_drop, inplace=True)\n"
+     ]
+    }
+   ],
+   "source": [
+    "file_path = './LungData/LUAD Clinical Data.xlsx'\n",
+    "progression_column = 'Progression (No: 0, Yes: 1)'\n",
+    "\n",
+    "clinical_variables_input, missing_values_mask = get_clinical_variables_input_and_mask(file_path, progression_column, sheet_name='LUAD_416_Discovery') #, columns=['Sex (Male: 0, Female: 1)'])#, 'Age (<75: 0, ≥75: 1)', 'Predominant histological pattern (Lepidic:1, Papillary: 2, Acinar: 3, Micropapillary: 4, Solid: 5)'])\n",
+    "print(f'Clinical input variables: {clinical_variables_input.shape}, Number of missing values: {missing_values_mask.sum()}')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 475,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "((334, 12288), (334, 11), (334,))"
+      ]
+     },
+     "execution_count": 475,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "X_res, X_tab, y, baseline = remove_rows_with_nan(resnet_embeddings, clinical_variables_input, progression_labels)\n",
+    "X_res.shape, X_tab.shape, y.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 472,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "((572, 12288), (572, 11), (572,))"
+      ]
+     },
+     "execution_count": 472,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "X_res_resampled, X_tab_resampled, y_resampled = get_oversampled_dataset(X_res, X_tab, y)\n",
+    "X_res_resampled.shape, X_tab_resampled.shape, y_resampled.shape"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 504,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Baseline for Sex (Male: 0, Female: 1): 84.16\n",
+      "Baseline for Age (<75: 0, ≥75: 1): 84.16\n",
+      "Baseline for BMI (<30: 0, ≥30: 1): 84.16\n",
+      "Baseline for Smoking Status (Smoker: 0, Non-smoker:1): 84.37\n",
+      "Baseline for Pack Years (1-30: 0, ≥30: 1): 85.37\n",
+      "Baseline for Stage (I-II: 0, III-IV:1): 84.37\n",
+      "Baseline for Predominant histological pattern (Lepidic:1, Papillary: 2, Acinar: 3, Micropapillary: 4, Solid: 5): 84.16\n"
+     ]
+    }
+   ],
+   "source": [
+    "clinical_variables_input_column_names = ['Sex (Male: 0, Female: 1)','Age (<75: 0, ≥75: 1)','BMI (<30: 0, ≥30: 1)','Smoking Status (Smoker: 0, Non-smoker:1)','Pack Years (1-30: 0, ≥30: 1)','Stage (I-II: 0, III-IV:1)','Predominant histological pattern (Lepidic:1, Papillary: 2, Acinar: 3, Micropapillary: 4, Solid: 5)']\n",
+    "\n",
+    "for clinical_variable in clinical_variables_input_column_names:\n",
+    "    # print(clinical_variable)\n",
+    "    clinical_variables_input, missing_values_mask = get_clinical_variables_input_and_mask(file_path, progression_column, sheet_name='LUAD_416_Discovery', columns=[clinical_variable])\n",
+    "    X_res, X_tab, y, baseline = remove_rows_with_nan(resnet_embeddings, clinical_variables_input, progression_labels)\n",
+    "    baseline = (1 - np.mean(y))*100\n",
+    "    print(f\"Baseline for {clinical_variable}: {baseline:.2f}\")\n",
+    "    # X_res_resampled, X_tab_resampled, y_resampled = get_oversampled_dataset(X_res, X_tab, y)\n",
+    "\n",
+    "    # train_multimodal_concat(X_res_resampled, y_resampled, X_tab_resampled, tabular=\"raw\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 474,
+   "metadata": {},
+   "outputs": [
+    {
+     "ename": "KeyboardInterrupt",
+     "evalue": "",
+     "output_type": "error",
+     "traceback": [
+      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
+      "Cell \u001b[0;32mIn[474], line 3\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[39m# TODO fix ValueError: X has 11 features, but MiniBatchSparsePCA is expecting 12288 features as input.\u001b[39;00m\n\u001b[1;32m      2\u001b[0m \u001b[39m# TODO Fix train svm pca model function to split train and test\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m train_multimodal_aggregate(X_res_resampled, X_tab_resampled, y_resampled)\n",
+      "Cell \u001b[0;32mIn[457], line 22\u001b[0m, in \u001b[0;36mtrain_multimodal_aggregate\u001b[0;34m(X_res, X_tab, y)\u001b[0m\n\u001b[1;32m     19\u001b[0m y_train, y_test \u001b[39m=\u001b[39m y[train_index], y[test_index]\n\u001b[1;32m     21\u001b[0m \u001b[39m# Train PCA and SVM on images\u001b[39;00m\n\u001b[0;32m---> 22\u001b[0m pca_model, svm_model \u001b[39m=\u001b[39m train_PCA_SVM_clf(X_res_train, y_train)\n\u001b[1;32m     23\u001b[0m \u001b[39m# Get predicted probabilities from the SVM model\u001b[39;00m\n\u001b[1;32m     24\u001b[0m predicted_probabilities_image_train \u001b[39m=\u001b[39m svm_model\u001b[39m.\u001b[39mpredict_proba(pca_model\u001b[39m.\u001b[39mtransform(X_res_train))\n",
+      "Cell \u001b[0;32mIn[448], line 6\u001b[0m, in \u001b[0;36mtrain_PCA_SVM_clf\u001b[0;34m(X_train, y_train)\u001b[0m\n\u001b[1;32m      3\u001b[0m pca_model \u001b[39m=\u001b[39m MiniBatchSparsePCA(n_components\u001b[39m=\u001b[39m\u001b[39m9\u001b[39m, batch_size\u001b[39m=\u001b[39m\u001b[39m500\u001b[39m,random_state\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m)\n\u001b[1;32m      5\u001b[0m \u001b[39m# Train PCA model\u001b[39;00m\n\u001b[0;32m----> 6\u001b[0m pca_model\u001b[39m.\u001b[39;49mfit(X_train)\n\u001b[1;32m      8\u001b[0m \u001b[39m# Get PCA embeddings of training data\u001b[39;00m\n\u001b[1;32m      9\u001b[0m X_tr_em \u001b[39m=\u001b[39m pca_model\u001b[39m.\u001b[39mtransform(X_train)\n",
+      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/decomposition/_sparse_pca.py:85\u001b[0m, in \u001b[0;36m_BaseSparsePCA.fit\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m     82\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m     83\u001b[0m     n_components \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_components\n\u001b[0;32m---> 85\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_fit(X, n_components, random_state)\n",
+      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/decomposition/_sparse_pca.py:527\u001b[0m, in \u001b[0;36mMiniBatchSparsePCA._fit\u001b[0;34m(self, X, n_components, random_state)\u001b[0m\n\u001b[1;32m    524\u001b[0m \u001b[39m\u001b[39m\u001b[39m\"\"\"Specialized `fit` for MiniBatchSparsePCA.\"\"\"\u001b[39;00m\n\u001b[1;32m    526\u001b[0m transform_algorithm \u001b[39m=\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mlasso_\u001b[39m\u001b[39m\"\u001b[39m \u001b[39m+\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmethod\n\u001b[0;32m--> 527\u001b[0m est \u001b[39m=\u001b[39m MiniBatchDictionaryLearning(\n\u001b[1;32m    528\u001b[0m     n_components\u001b[39m=\u001b[39;49mn_components,\n\u001b[1;32m    529\u001b[0m     alpha\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49malpha,\n\u001b[1;32m    530\u001b[0m     n_iter\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mn_iter,\n\u001b[1;32m    531\u001b[0m     max_iter\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmax_iter,\n\u001b[1;32m    532\u001b[0m     dict_init\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m    533\u001b[0m     batch_size\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mbatch_size,\n\u001b[1;32m    534\u001b[0m     shuffle\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mshuffle,\n\u001b[1;32m    535\u001b[0m     n_jobs\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mn_jobs,\n\u001b[1;32m    536\u001b[0m     fit_algorithm\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmethod,\n\u001b[1;32m    537\u001b[0m     random_state\u001b[39m=\u001b[39;49mrandom_state,\n\u001b[1;32m    538\u001b[0m     transform_algorithm\u001b[39m=\u001b[39;49mtransform_algorithm,\n\u001b[1;32m    539\u001b[0m     
transform_alpha\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49malpha,\n\u001b[1;32m    540\u001b[0m     verbose\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mverbose,\n\u001b[1;32m    541\u001b[0m     callback\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcallback,\n\u001b[1;32m    542\u001b[0m     tol\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtol,\n\u001b[1;32m    543\u001b[0m     max_no_improvement\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmax_no_improvement,\n\u001b[1;32m    544\u001b[0m )\u001b[39m.\u001b[39;49mfit(X\u001b[39m.\u001b[39;49mT)\n\u001b[1;32m    546\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcomponents_, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_iter_ \u001b[39m=\u001b[39m est\u001b[39m.\u001b[39mtransform(X\u001b[39m.\u001b[39mT)\u001b[39m.\u001b[39mT, est\u001b[39m.\u001b[39mn_iter_\n\u001b[1;32m    548\u001b[0m components_norm \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39mlinalg\u001b[39m.\u001b[39mnorm(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcomponents_, axis\u001b[39m=\u001b[39m\u001b[39m1\u001b[39m)[:, np\u001b[39m.\u001b[39mnewaxis]\n",
+      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/decomposition/_dict_learning.py:2377\u001b[0m, in \u001b[0;36mMiniBatchDictionaryLearning.fit\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m   2374\u001b[0m batches \u001b[39m=\u001b[39m itertools\u001b[39m.\u001b[39mcycle(batches)\n\u001b[1;32m   2376\u001b[0m \u001b[39mfor\u001b[39;00m i, batch \u001b[39min\u001b[39;00m \u001b[39mzip\u001b[39m(\u001b[39mrange\u001b[39m(n_iter), batches):\n\u001b[0;32m-> 2377\u001b[0m     \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_minibatch_step(X_train[batch], dictionary, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_random_state, i)\n\u001b[1;32m   2379\u001b[0m     trigger_verbose \u001b[39m=\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mverbose \u001b[39mand\u001b[39;00m i \u001b[39m%\u001b[39m ceil(\u001b[39m100.0\u001b[39m \u001b[39m/\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mverbose) \u001b[39m==\u001b[39m \u001b[39m0\u001b[39m\n\u001b[1;32m   2380\u001b[0m     \u001b[39mif\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mverbose \u001b[39m>\u001b[39m \u001b[39m10\u001b[39m \u001b[39mor\u001b[39;00m trigger_verbose:\n",
+      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/decomposition/_dict_learning.py:2166\u001b[0m, in \u001b[0;36mMiniBatchDictionaryLearning._minibatch_step\u001b[0;34m(self, X, dictionary, random_state, step)\u001b[0m\n\u001b[1;32m   2163\u001b[0m batch_size \u001b[39m=\u001b[39m X\u001b[39m.\u001b[39mshape[\u001b[39m0\u001b[39m]\n\u001b[1;32m   2165\u001b[0m \u001b[39m# Compute code for this batch\u001b[39;00m\n\u001b[0;32m-> 2166\u001b[0m code \u001b[39m=\u001b[39m sparse_encode(\n\u001b[1;32m   2167\u001b[0m     X,\n\u001b[1;32m   2168\u001b[0m     dictionary,\n\u001b[1;32m   2169\u001b[0m     algorithm\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_fit_algorithm,\n\u001b[1;32m   2170\u001b[0m     alpha\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49malpha,\n\u001b[1;32m   2171\u001b[0m     n_jobs\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mn_jobs,\n\u001b[1;32m   2172\u001b[0m     check_input\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m,\n\u001b[1;32m   2173\u001b[0m     positive\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mpositive_code,\n\u001b[1;32m   2174\u001b[0m     max_iter\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mtransform_max_iter,\n\u001b[1;32m   2175\u001b[0m     verbose\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mverbose,\n\u001b[1;32m   2176\u001b[0m )\n\u001b[1;32m   2178\u001b[0m batch_cost \u001b[39m=\u001b[39m (\n\u001b[1;32m   2179\u001b[0m     \u001b[39m0.5\u001b[39m \u001b[39m*\u001b[39m ((X \u001b[39m-\u001b[39m code \u001b[39m@\u001b[39m dictionary) \u001b[39m*\u001b[39m\u001b[39m*\u001b[39m \u001b[39m2\u001b[39m)\u001b[39m.\u001b[39msum()\n\u001b[1;32m   2180\u001b[0m     \u001b[39m+\u001b[39m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39malpha \u001b[39m*\u001b[39m 
np\u001b[39m.\u001b[39msum(np\u001b[39m.\u001b[39mabs(code))\n\u001b[1;32m   2181\u001b[0m ) \u001b[39m/\u001b[39m batch_size\n\u001b[1;32m   2183\u001b[0m \u001b[39m# Update inner stats\u001b[39;00m\n",
+      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/decomposition/_dict_learning.py:378\u001b[0m, in \u001b[0;36msparse_encode\u001b[0;34m(X, dictionary, gram, cov, algorithm, n_nonzero_coefs, alpha, copy_cov, init, max_iter, n_jobs, check_input, verbose, positive)\u001b[0m\n\u001b[1;32m    375\u001b[0m         regularization \u001b[39m=\u001b[39m \u001b[39m1.0\u001b[39m\n\u001b[1;32m    377\u001b[0m \u001b[39mif\u001b[39;00m effective_n_jobs(n_jobs) \u001b[39m==\u001b[39m \u001b[39m1\u001b[39m \u001b[39mor\u001b[39;00m algorithm \u001b[39m==\u001b[39m \u001b[39m\"\u001b[39m\u001b[39mthreshold\u001b[39m\u001b[39m\"\u001b[39m:\n\u001b[0;32m--> 378\u001b[0m     code \u001b[39m=\u001b[39m _sparse_encode(\n\u001b[1;32m    379\u001b[0m         X,\n\u001b[1;32m    380\u001b[0m         dictionary,\n\u001b[1;32m    381\u001b[0m         gram,\n\u001b[1;32m    382\u001b[0m         cov\u001b[39m=\u001b[39;49mcov,\n\u001b[1;32m    383\u001b[0m         algorithm\u001b[39m=\u001b[39;49malgorithm,\n\u001b[1;32m    384\u001b[0m         regularization\u001b[39m=\u001b[39;49mregularization,\n\u001b[1;32m    385\u001b[0m         copy_cov\u001b[39m=\u001b[39;49mcopy_cov,\n\u001b[1;32m    386\u001b[0m         init\u001b[39m=\u001b[39;49minit,\n\u001b[1;32m    387\u001b[0m         max_iter\u001b[39m=\u001b[39;49mmax_iter,\n\u001b[1;32m    388\u001b[0m         check_input\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m,\n\u001b[1;32m    389\u001b[0m         verbose\u001b[39m=\u001b[39;49mverbose,\n\u001b[1;32m    390\u001b[0m         positive\u001b[39m=\u001b[39;49mpositive,\n\u001b[1;32m    391\u001b[0m     )\n\u001b[1;32m    392\u001b[0m     \u001b[39mreturn\u001b[39;00m code\n\u001b[1;32m    394\u001b[0m \u001b[39m# Enter parallel code block\u001b[39;00m\n",
+      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/decomposition/_dict_learning.py:156\u001b[0m, in \u001b[0;36m_sparse_encode\u001b[0;34m(X, dictionary, gram, cov, algorithm, regularization, copy_cov, init, max_iter, check_input, verbose, positive)\u001b[0m\n\u001b[1;32m    145\u001b[0m     \u001b[39m# Not passing in verbose=max(0, verbose-1) because Lars.fit already\u001b[39;00m\n\u001b[1;32m    146\u001b[0m     \u001b[39m# corrects the verbosity level.\u001b[39;00m\n\u001b[1;32m    147\u001b[0m     lasso_lars \u001b[39m=\u001b[39m LassoLars(\n\u001b[1;32m    148\u001b[0m         alpha\u001b[39m=\u001b[39malpha,\n\u001b[1;32m    149\u001b[0m         fit_intercept\u001b[39m=\u001b[39m\u001b[39mFalse\u001b[39;00m,\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m    154\u001b[0m         max_iter\u001b[39m=\u001b[39mmax_iter,\n\u001b[1;32m    155\u001b[0m     )\n\u001b[0;32m--> 156\u001b[0m     lasso_lars\u001b[39m.\u001b[39;49mfit(dictionary\u001b[39m.\u001b[39;49mT, X\u001b[39m.\u001b[39;49mT, Xy\u001b[39m=\u001b[39;49mcov)\n\u001b[1;32m    157\u001b[0m     new_code \u001b[39m=\u001b[39m lasso_lars\u001b[39m.\u001b[39mcoef_\n\u001b[1;32m    158\u001b[0m \u001b[39mfinally\u001b[39;00m:\n",
+      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/linear_model/_least_angle.py:1144\u001b[0m, in \u001b[0;36mLars.fit\u001b[0;34m(self, X, y, Xy)\u001b[0m\n\u001b[1;32m   1141\u001b[0m     noise \u001b[39m=\u001b[39m rng\u001b[39m.\u001b[39muniform(high\u001b[39m=\u001b[39m\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mjitter, size\u001b[39m=\u001b[39m\u001b[39mlen\u001b[39m(y))\n\u001b[1;32m   1142\u001b[0m     y \u001b[39m=\u001b[39m y \u001b[39m+\u001b[39m noise\n\u001b[0;32m-> 1144\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_fit(\n\u001b[1;32m   1145\u001b[0m     X,\n\u001b[1;32m   1146\u001b[0m     y,\n\u001b[1;32m   1147\u001b[0m     max_iter\u001b[39m=\u001b[39;49mmax_iter,\n\u001b[1;32m   1148\u001b[0m     alpha\u001b[39m=\u001b[39;49malpha,\n\u001b[1;32m   1149\u001b[0m     fit_path\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mfit_path,\n\u001b[1;32m   1150\u001b[0m     normalize\u001b[39m=\u001b[39;49m_normalize,\n\u001b[1;32m   1151\u001b[0m     Xy\u001b[39m=\u001b[39;49mXy,\n\u001b[1;32m   1152\u001b[0m )\n\u001b[1;32m   1154\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\n",
+      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/linear_model/_least_angle.py:1077\u001b[0m, in \u001b[0;36mLars._fit\u001b[0;34m(self, X, y, max_iter, alpha, fit_path, normalize, Xy)\u001b[0m\n\u001b[1;32m   1075\u001b[0m \u001b[39mfor\u001b[39;00m k \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(n_targets):\n\u001b[1;32m   1076\u001b[0m     this_Xy \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m \u001b[39mif\u001b[39;00m Xy \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39melse\u001b[39;00m Xy[:, k]\n\u001b[0;32m-> 1077\u001b[0m     alphas, _, \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mcoef_[k], n_iter_ \u001b[39m=\u001b[39m lars_path(\n\u001b[1;32m   1078\u001b[0m         X,\n\u001b[1;32m   1079\u001b[0m         y[:, k],\n\u001b[1;32m   1080\u001b[0m         Gram\u001b[39m=\u001b[39;49mGram,\n\u001b[1;32m   1081\u001b[0m         Xy\u001b[39m=\u001b[39;49mthis_Xy,\n\u001b[1;32m   1082\u001b[0m         copy_X\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mcopy_X,\n\u001b[1;32m   1083\u001b[0m         copy_Gram\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m   1084\u001b[0m         alpha_min\u001b[39m=\u001b[39;49malpha,\n\u001b[1;32m   1085\u001b[0m         method\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mmethod,\n\u001b[1;32m   1086\u001b[0m         verbose\u001b[39m=\u001b[39;49m\u001b[39mmax\u001b[39;49m(\u001b[39m0\u001b[39;49m, \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mverbose \u001b[39m-\u001b[39;49m \u001b[39m1\u001b[39;49m),\n\u001b[1;32m   1087\u001b[0m         max_iter\u001b[39m=\u001b[39;49mmax_iter,\n\u001b[1;32m   1088\u001b[0m         eps\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49meps,\n\u001b[1;32m   1089\u001b[0m         return_path\u001b[39m=\u001b[39;49m\u001b[39mFalse\u001b[39;49;00m,\n\u001b[1;32m   1090\u001b[0m         
return_n_iter\u001b[39m=\u001b[39;49m\u001b[39mTrue\u001b[39;49;00m,\n\u001b[1;32m   1091\u001b[0m         positive\u001b[39m=\u001b[39;49m\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mpositive,\n\u001b[1;32m   1092\u001b[0m     )\n\u001b[1;32m   1093\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39malphas_\u001b[39m.\u001b[39mappend(alphas)\n\u001b[1;32m   1094\u001b[0m     \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mn_iter_\u001b[39m.\u001b[39mappend(n_iter_)\n",
+      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/linear_model/_least_angle.py:170\u001b[0m, in \u001b[0;36mlars_path\u001b[0;34m(X, y, Xy, Gram, max_iter, alpha_min, method, copy_X, eps, copy_Gram, verbose, return_path, return_n_iter, positive)\u001b[0m\n\u001b[1;32m    165\u001b[0m \u001b[39mif\u001b[39;00m X \u001b[39mis\u001b[39;00m \u001b[39mNone\u001b[39;00m \u001b[39mand\u001b[39;00m Gram \u001b[39mis\u001b[39;00m \u001b[39mnot\u001b[39;00m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m    166\u001b[0m     \u001b[39mraise\u001b[39;00m \u001b[39mValueError\u001b[39;00m(\n\u001b[1;32m    167\u001b[0m         \u001b[39m\"\u001b[39m\u001b[39mX cannot be None if Gram is not None\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m    168\u001b[0m         \u001b[39m\"\u001b[39m\u001b[39mUse lars_path_gram to avoid passing X and y.\u001b[39m\u001b[39m\"\u001b[39m\n\u001b[1;32m    169\u001b[0m     )\n\u001b[0;32m--> 170\u001b[0m \u001b[39mreturn\u001b[39;00m _lars_path_solver(\n\u001b[1;32m    171\u001b[0m     X\u001b[39m=\u001b[39;49mX,\n\u001b[1;32m    172\u001b[0m     y\u001b[39m=\u001b[39;49my,\n\u001b[1;32m    173\u001b[0m     Xy\u001b[39m=\u001b[39;49mXy,\n\u001b[1;32m    174\u001b[0m     Gram\u001b[39m=\u001b[39;49mGram,\n\u001b[1;32m    175\u001b[0m     n_samples\u001b[39m=\u001b[39;49m\u001b[39mNone\u001b[39;49;00m,\n\u001b[1;32m    176\u001b[0m     max_iter\u001b[39m=\u001b[39;49mmax_iter,\n\u001b[1;32m    177\u001b[0m     alpha_min\u001b[39m=\u001b[39;49malpha_min,\n\u001b[1;32m    178\u001b[0m     method\u001b[39m=\u001b[39;49mmethod,\n\u001b[1;32m    179\u001b[0m     copy_X\u001b[39m=\u001b[39;49mcopy_X,\n\u001b[1;32m    180\u001b[0m     eps\u001b[39m=\u001b[39;49meps,\n\u001b[1;32m    181\u001b[0m     copy_Gram\u001b[39m=\u001b[39;49mcopy_Gram,\n\u001b[1;32m    182\u001b[0m     verbose\u001b[39m=\u001b[39;49mverbose,\n\u001b[1;32m    183\u001b[0m     
return_path\u001b[39m=\u001b[39;49mreturn_path,\n\u001b[1;32m    184\u001b[0m     return_n_iter\u001b[39m=\u001b[39;49mreturn_n_iter,\n\u001b[1;32m    185\u001b[0m     positive\u001b[39m=\u001b[39;49mpositive,\n\u001b[1;32m    186\u001b[0m )\n",
+      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/sklearn/linear_model/_least_angle.py:729\u001b[0m, in \u001b[0;36m_lars_path_solver\u001b[0;34m(X, y, Xy, Gram, n_samples, max_iter, alpha_min, method, copy_X, eps, copy_Gram, verbose, return_path, return_n_iter, positive)\u001b[0m\n\u001b[1;32m    725\u001b[0m     corr_eq_dir \u001b[39m=\u001b[39m np\u001b[39m.\u001b[39mdot(Gram[:n_active, n_active:]\u001b[39m.\u001b[39mT, least_squares)\n\u001b[1;32m    727\u001b[0m \u001b[39m# Explicit rounding can be necessary to avoid `np.argmax(Cov)` yielding\u001b[39;00m\n\u001b[1;32m    728\u001b[0m \u001b[39m# unstable results because of rounding errors.\u001b[39;00m\n\u001b[0;32m--> 729\u001b[0m np\u001b[39m.\u001b[39;49maround(corr_eq_dir, decimals\u001b[39m=\u001b[39;49mcov_precision, out\u001b[39m=\u001b[39;49mcorr_eq_dir)\n\u001b[1;32m    731\u001b[0m g1 \u001b[39m=\u001b[39m arrayfuncs\u001b[39m.\u001b[39mmin_pos((C \u001b[39m-\u001b[39m Cov) \u001b[39m/\u001b[39m (AA \u001b[39m-\u001b[39m corr_eq_dir \u001b[39m+\u001b[39m tiny32))\n\u001b[1;32m    732\u001b[0m \u001b[39mif\u001b[39;00m positive:\n",
+      "File \u001b[0;32m<__array_function__ internals>:200\u001b[0m, in \u001b[0;36maround\u001b[0;34m(*args, **kwargs)\u001b[0m\n",
+      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/numpy/core/fromnumeric.py:3337\u001b[0m, in \u001b[0;36maround\u001b[0;34m(a, decimals, out)\u001b[0m\n\u001b[1;32m   3245\u001b[0m \u001b[39m@array_function_dispatch\u001b[39m(_around_dispatcher)\n\u001b[1;32m   3246\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39maround\u001b[39m(a, decimals\u001b[39m=\u001b[39m\u001b[39m0\u001b[39m, out\u001b[39m=\u001b[39m\u001b[39mNone\u001b[39;00m):\n\u001b[1;32m   3247\u001b[0m \u001b[39m    \u001b[39m\u001b[39m\"\"\"\u001b[39;00m\n\u001b[1;32m   3248\u001b[0m \u001b[39m    Evenly round to the given number of decimals.\u001b[39;00m\n\u001b[1;32m   3249\u001b[0m \n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m   3335\u001b[0m \n\u001b[1;32m   3336\u001b[0m \u001b[39m    \"\"\"\u001b[39;00m\n\u001b[0;32m-> 3337\u001b[0m     \u001b[39mreturn\u001b[39;00m _wrapfunc(a, \u001b[39m'\u001b[39;49m\u001b[39mround\u001b[39;49m\u001b[39m'\u001b[39;49m, decimals\u001b[39m=\u001b[39;49mdecimals, out\u001b[39m=\u001b[39;49mout)\n",
+      "File \u001b[0;32m/Library/Frameworks/Python.framework/Versions/3.9/lib/python3.9/site-packages/numpy/core/fromnumeric.py:57\u001b[0m, in \u001b[0;36m_wrapfunc\u001b[0;34m(obj, method, *args, **kwds)\u001b[0m\n\u001b[1;32m     54\u001b[0m     \u001b[39mreturn\u001b[39;00m _wrapit(obj, method, \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwds)\n\u001b[1;32m     56\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m---> 57\u001b[0m     \u001b[39mreturn\u001b[39;00m bound(\u001b[39m*\u001b[39;49margs, \u001b[39m*\u001b[39;49m\u001b[39m*\u001b[39;49mkwds)\n\u001b[1;32m     58\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mTypeError\u001b[39;00m:\n\u001b[1;32m     59\u001b[0m     \u001b[39m# A TypeError occurs if the object does have such a method in its\u001b[39;00m\n\u001b[1;32m     60\u001b[0m     \u001b[39m# class, but its signature is not identical to that of NumPy's. This\u001b[39;00m\n\u001b[0;32m   (...)\u001b[0m\n\u001b[1;32m     64\u001b[0m     \u001b[39m# Call _wrapit from within the except clause to ensure a potential\u001b[39;00m\n\u001b[1;32m     65\u001b[0m     \u001b[39m# exception has a traceback chain.\u001b[39;00m\n\u001b[1;32m     66\u001b[0m     \u001b[39mreturn\u001b[39;00m _wrapit(obj, method, \u001b[39m*\u001b[39margs, \u001b[39m*\u001b[39m\u001b[39m*\u001b[39mkwds)\n",
+      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
      ]
     }
    ],
@@ -1380,7 +1389,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 330,
+   "execution_count": null,
    "metadata": {},
    "outputs": [
     {
@@ -1469,7 +1478,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 367,
+   "execution_count": null,
    "metadata": {},
    "outputs": [
     {
@@ -1536,6 +1545,168 @@
     "print(f'Average test precision: {np.mean(precision_scores)}')\n",
     "print(f'Average test recall: {np.mean(recall_scores)}')"
    ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 499,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Header: Sex (Male: 0, Female: 1)\n",
+      "Accuracies: ['94.85', '93.38', '94.12', '94.85', '93.38']\n",
+      "Average Accuracy: 94.12\n",
+      "Standard Deviation of Accuracies: 0.66\n",
+      "Average Precision: 90.29\n",
+      "Average Recall: 98.84\n",
+      "\n",
+      "Latex string\n",
+      "& 94.85 & 93.38 & 94.12 & 94.85 & 93.38 & 94.12 & 0.66 & 90.29 & 98.84\n",
+      "\n",
+      "Header: Age (<75: 0, ≥75: 1)\n",
+      "Accuracies: ['94.85', '93.38', '94.12', '95.59', '93.38']\n",
+      "Average Accuracy: 94.26\n",
+      "Standard Deviation of Accuracies: 0.86\n",
+      "Average Precision: 90.53\n",
+      "Average Recall: 98.84\n",
+      "\n",
+      "Latex string\n",
+      "& 94.85 & 93.38 & 94.12 & 95.59 & 93.38 & 94.26 & 0.86 & 90.53 & 98.84\n",
+      "\n",
+      "Header: BMI (<30: 0, ≥30: 1)\n",
+      "Accuracies: ['94.85', '93.38', '94.12', '94.85', '93.38']\n",
+      "Average Accuracy: 94.12\n",
+      "Standard Deviation of Accuracies: 0.66\n",
+      "Average Precision: 90.29\n",
+      "Average Recall: 98.84\n",
+      "\n",
+      "Latex string\n",
+      "& 94.85 & 93.38 & 94.12 & 94.85 & 93.38 & 94.12 & 0.66 & 90.29 & 98.84\n",
+      "\n",
+      "Header: Smoking Status (Smoker: 0, Non-smoker:1)\n",
+      "Accuracies: ['91.91', '96.32', '95.59', '97.06', '97.06']\n",
+      "Average Accuracy: 95.59\n",
+      "Standard Deviation of Accuracies: 1.92\n",
+      "Average Precision: 92.54\n",
+      "Average Recall: 99.38\n",
+      "\n",
+      "Latex string\n",
+      "& 91.91 & 96.32 & 95.59 & 97.06 & 97.06 & 95.59 & 1.92 & 92.54 & 99.38\n",
+      "\n",
+      "Header: Pack Years (1-30: 0, ≥30: 1)\n",
+      "Accuracies: ['95.65', '93.91', '94.74', '93.86', '98.25']\n",
+      "Average Accuracy: 95.28\n",
+      "Standard Deviation of Accuracies: 1.62\n",
+      "Average Precision: 93.01\n",
+      "Average Recall: 97.95\n",
+      "\n",
+      "Latex string\n",
+      "& 95.65 & 93.91 & 94.74 & 93.86 & 98.25 & 95.28 & 1.62 & 93.01 & 97.95\n",
+      "\n",
+      "Header: Stage (I-II: 0, III-IV:1)\n",
+      "Accuracies: ['93.38', '96.32', '94.12', '97.79', '96.32']\n",
+      "Average Accuracy: 95.59\n",
+      "Standard Deviation of Accuracies: 1.61\n",
+      "Average Precision: 92.63\n",
+      "Average Recall: 99.09\n",
+      "\n",
+      "Latex string\n",
+      "& 93.38 & 96.32 & 94.12 & 97.79 & 96.32 & 95.59 & 1.61 & 92.63 & 99.09\n",
+      "\n",
+      "Header: Predominant histological pattern (Lepidic:1, Papillary: 2, Acinar: 3, Micropapillary: 4, Solid: 5)\n",
+      "Accuracies: ['97.79', '94.85', '94.12', '97.06', '94.85']\n",
+      "Average Accuracy: 95.74\n",
+      "Standard Deviation of Accuracies: 1.43\n",
+      "Average Precision: 93.06\n",
+      "Average Recall: 98.84\n",
+      "\n",
+      "Latex string\n",
+      "& 97.79 & 94.85 & 94.12 & 97.06 & 94.85 & 95.74 & 1.43 & 93.06 & 98.84\n",
+      "\n"
+     ]
+    }
+   ],
+   "source": [
+    "import numpy as np  # already imported at notebook top; kept so this cell runs standalone\n",
+    "# Summarize 5-fold CV results per clinical variable and print LaTeX table rows.\n",
+    "# Section headers exactly as they appear in results_each_clin_var.txt.\n",
+    "headers = [\n",
+    "    'Sex (Male: 0, Female: 1)',\n",
+    "    'Age (<75: 0, ≥75: 1)',\n",
+    "    'BMI (<30: 0, ≥30: 1)',\n",
+    "    'Smoking Status (Smoker: 0, Non-smoker:1)',\n",
+    "    'Pack Years (1-30: 0, ≥30: 1)',\n",
+    "    'Stage (I-II: 0, III-IV:1)',\n",
+    "    'Predominant histological pattern (Lepidic:1, Papillary: 2, Acinar: 3, Micropapillary: 4, Solid: 5)'\n",
+    "]\n",
+    "\n",
+    "# Per-header accumulators for the five folds' metrics.\n",
+    "results = {header: {'accuracies': [], 'precisions': [], 'recalls': []} for header in headers}\n",
+    "\n",
+    "# Read the file\n",
+    "with open('./results_each_clin_var.txt', 'r') as file:\n",
+    "    lines = file.readlines()\n",
+    "\n",
+    "# Parse: a header line opens a section; metric lines are attributed to it.\n",
+    "current_header = None\n",
+    "for line in lines:\n",
+    "    line = line.strip()\n",
+    "    if line in headers:\n",
+    "        current_header = line\n",
+    "    elif 'Test accuracy' in line:  # capital T skips the 'Average test accuracy' summary lines\n",
+    "        accuracy = float(line.split(': ')[1]) * 100\n",
+    "        results[current_header]['accuracies'].append(accuracy)\n",
+    "    elif 'Test precision' in line:\n",
+    "        precision = float(line.split(': ')[1]) * 100\n",
+    "        results[current_header]['precisions'].append(precision)\n",
+    "    elif 'Test recall' in line:\n",
+    "        recall = float(line.split(': ')[1]) * 100\n",
+    "        results[current_header]['recalls'].append(recall)\n",
+    "\n",
+    "# Aggregate statistics per header (all values already converted to percent).\n",
+    "for header in headers:\n",
+    "    accuracies = results[header]['accuracies']\n",
+    "    precisions = results[header]['precisions']\n",
+    "    recalls = results[header]['recalls']\n",
+    "    \n",
+    "    avg_accuracy = np.mean(accuracies)\n",
+    "    std_accuracy = np.std(accuracies)  # population std (ddof=0), matching the values recorded in the results file\n",
+    "    avg_precision = np.mean(precisions)\n",
+    "    avg_recall = np.mean(recalls)\n",
+    "    \n",
+    "    results[header]['avg_accuracy'] = avg_accuracy\n",
+    "    results[header]['std_accuracy'] = std_accuracy\n",
+    "    results[header]['avg_precision'] = avg_precision\n",
+    "    results[header]['avg_recall'] = avg_recall\n",
+    "\n",
+    "    formatted_accuracies = [f\"{acc:.2f}\" for acc in accuracies]\n",
+    "\n",
+    "    print(f\"Header: {header}\")\n",
+    "    print(f\"Accuracies: {formatted_accuracies}\")\n",
+    "    print(f\"Average Accuracy: {avg_accuracy:.2f}\")\n",
+    "    print(f\"Standard Deviation of Accuracies: {std_accuracy:.2f}\")\n",
+    "    print(f\"Average Precision: {avg_precision:.2f}\")\n",
+    "    print(f\"Average Recall: {avg_recall:.2f}\")\n",
+    "    print()\n",
+    "\n",
+    "    # Build one LaTeX table row: fold accuracies, then mean, std, precision, recall.\n",
+    "    latex_string = \"&\"\n",
+    "    \n",
+    "    for acc in accuracies:\n",
+    "        latex_string += f\" {acc:.2f} &\"\n",
+    "\n",
+    "    latex_string += f\" {avg_accuracy:.2f} &\"\n",
+    "    latex_string += f\" {std_accuracy:.2f} &\"\n",
+    "    latex_string += f\" {avg_precision:.2f} &\"\n",
+    "    latex_string += f\" {avg_recall:.2f}\"\n",
+    "\n",
+    "    print('Latex string')\n",
+    "    print(latex_string)\n",
+    "    print()\n",
+    "\n",
+    "# Note: Results are in percent as required.\n"
+   ]
   }
  ],
  "metadata": {
diff --git a/results_each_clin_var.txt b/results_each_clin_var.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1e985bbecb504e7502cc571d818ec3ec97537929
--- /dev/null
+++ b/results_each_clin_var.txt
@@ -0,0 +1,223 @@
+Sex (Male: 0, Female: 1)
+Fold 1
+Test accuracy: 0.9485294117647058
+Test precision: 0.9066666666666666
+Test recall: 1.0
+
+Fold 2
+Test accuracy: 0.9338235294117647
+Test precision: 0.8873239436619719
+Test recall: 0.984375
+
+Fold 3
+Test accuracy: 0.9411764705882353
+Test precision: 0.8904109589041096
+Test recall: 1.0
+
+Fold 4
+Test accuracy: 0.9485294117647058
+Test precision: 0.9113924050632911
+Test recall: 1.0
+
+Fold 5
+Test accuracy: 0.9338235294117647
+Test precision: 0.918918918918919
+Test recall: 0.9577464788732394
+
+Baseline: 0.8415841584158416
+Average test accuracy: 0.9411764705882352
+Standard deviation of accuracies: 0.006576670522058182
+Average test precision: 0.9029425786429917
+Average test recall: 0.9884242957746479
+
+Age (<75: 0, ≥75: 1)
+Fold 1
+Test accuracy: 0.9485294117647058
+Test precision: 0.9066666666666666
+Test recall: 1.0
+
+Fold 2
+Test accuracy: 0.9338235294117647
+Test precision: 0.8873239436619719
+Test recall: 0.984375
+
+Fold 3
+Test accuracy: 0.9411764705882353
+Test precision: 0.8904109589041096
+Test recall: 1.0
+
+Fold 4
+Test accuracy: 0.9558823529411765
+Test precision: 0.9230769230769231
+Test recall: 1.0
+
+Fold 5
+Test accuracy: 0.9338235294117647
+Test precision: 0.918918918918919
+Test recall: 0.9577464788732394
+
+Baseline: 0.8415841584158416
+Average test accuracy: 0.9426470588235294
+Standard deviation of accuracies: 0.008574929257125446
+Average test precision: 0.9052794822457182
+Average test recall: 0.9884242957746479
+
+BMI (<30: 0, ≥30: 1)
+Fold 1
+Test accuracy: 0.9485294117647058
+Test precision: 0.9066666666666666
+Test recall: 1.0
+
+Fold 2
+Test accuracy: 0.9338235294117647
+Test precision: 0.8873239436619719
+Test recall: 0.984375
+
+Fold 3
+Test accuracy: 0.9411764705882353
+Test precision: 0.8904109589041096
+Test recall: 1.0
+
+Fold 4
+Test accuracy: 0.9485294117647058
+Test precision: 0.9113924050632911
+Test recall: 1.0
+
+Fold 5
+Test accuracy: 0.9338235294117647
+Test precision: 0.918918918918919
+Test recall: 0.9577464788732394
+
+Baseline: 0.8415841584158416
+Average test accuracy: 0.9411764705882352
+Standard deviation of accuracies: 0.006576670522058182
+Average test precision: 0.9029425786429917
+Average test recall: 0.9884242957746479
+
+Smoking Status (Smoker: 0, Non-smoker:1)
+Fold 1
+Test accuracy: 0.9191176470588235
+Test precision: 0.8658536585365854
+Test recall: 1.0
+
+Fold 2
+Test accuracy: 0.9632352941176471
+Test precision: 0.9393939393939394
+Test recall: 0.9841269841269841
+
+Fold 3
+Test accuracy: 0.9558823529411765
+Test precision: 0.9178082191780822
+Test recall: 1.0
+
+Fold 4
+Test accuracy: 0.9705882352941176
+Test precision: 0.9473684210526315
+Test recall: 1.0
+
+Fold 5
+Test accuracy: 0.9705882352941176
+Test precision: 0.9565217391304348
+Test recall: 0.9850746268656716
+
+Baseline: 0.8415841584158416
+Average test accuracy: 0.9558823529411764
+Standard deviation of accuracies: 0.01917412472118428
+Average test precision: 0.9253891954583346
+Average test recall: 0.9938403221985311
+
+Pack Years (1-30: 0, ≥30: 1)
+Fold 1
+Test accuracy: 0.9565217391304348
+Test precision: 0.9649122807017544
+Test recall: 0.9482758620689655
+
+Fold 2
+Test accuracy: 0.9391304347826087
+Test precision: 0.8793103448275862
+Test recall: 1.0
+
+Fold 3
+Test accuracy: 0.9473684210526315
+Test precision: 0.9047619047619048
+Test recall: 1.0
+
+Fold 4
+Test accuracy: 0.9385964912280702
+Test precision: 0.9333333333333333
+Test recall: 0.9491525423728814
+
+Fold 5
+Test accuracy: 0.9824561403508771
+Test precision: 0.9682539682539683
+Test recall: 1.0
+
+Baseline: 0.8415841584158416
+Average test accuracy: 0.9528146453089246
+Standard deviation of accuracies: 0.0161952811322858
+Average test precision: 0.9301143663757093
+Average test recall: 0.9794856808883694
+
+Stage (I-II: 0, III-IV:1)
+Fold 1
+Test accuracy: 0.9338235294117647
+Test precision: 0.8846153846153846
+Test recall: 1.0
+
+Fold 2
+Test accuracy: 0.9632352941176471
+Test precision: 0.9402985074626866
+Test recall: 0.984375
+
+Fold 3
+Test accuracy: 0.9411764705882353
+Test precision: 0.9014084507042254
+Test recall: 0.9846153846153847
+
+Fold 4
+Test accuracy: 0.9779411764705882
+Test precision: 0.9605263157894737
+Test recall: 1.0
+
+Fold 5
+Test accuracy: 0.9632352941176471
+Test precision: 0.9444444444444444
+Test recall: 0.9855072463768116
+
+Baseline: 0.8415841584158416
+Average test accuracy: 0.9558823529411764
+Standard deviation of accuracies: 0.016109486985446054
+Average test precision: 0.9262586206032429
+Average test recall: 0.9908995261984392
+
+Predominant histological pattern (Lepidic:1, Papillary: 2, Acinar: 3, Micropapillary: 4, Solid: 5)
+Fold 1
+Test accuracy: 0.9779411764705882
+Test precision: 0.9577464788732394
+Test recall: 1.0
+
+Fold 2
+Test accuracy: 0.9485294117647058
+Test precision: 0.9130434782608695
+Test recall: 0.984375
+
+Fold 3
+Test accuracy: 0.9411764705882353
+Test precision: 0.8904109589041096
+Test recall: 1.0
+
+Fold 4
+Test accuracy: 0.9705882352941176
+Test precision: 0.9473684210526315
+Test recall: 1.0
+
+Fold 5
+Test accuracy: 0.9485294117647058
+Test precision: 0.9444444444444444
+Test recall: 0.9577464788732394
+
+Baseline: 0.8415841584158416
+Average test accuracy: 0.9573529411764705
+Standard deviation of accuracies: 0.01425788193357744
+Average test precision: 0.9306027563070588
+Average test recall: 0.9884242957746479