Training a machine to determine whether a mushroom is edible

It's been a while since my last blog post, but we've been busy with a big move from Houston to Brooklyn. The opportunities in New York City for data science and AI seem endless! I've also been spending some time putting my newly acquired machine learning knowledge into practice by browsing through open datasets.

One dataset that piqued my interest is the mushroom dataset from the UCI Machine Learning Repository describing different species from the genera Agaricus and Lepiota. The data are taken from The Audubon Society Field Guide to North American Mushrooms, which states "there is no simple rule for determining the edibility of a mushroom". Challenged by this bold claim, I wanted to explore if a machine could succeed here. In addition to answering this question, this post explores some common issues in machine learning and how to use Python's go-to machine learning library, Scikit-learn, to address them.

Table of contents

  1. Inspecting the dataset
  2. Data wrangling
  3. Training a model with the hold-out method
  4. Using nested cross-validation to evaluate performance
  5. Identifying the most influential features

1. Inspecting the dataset

Let's begin by loading the dataset.

In [2]:
import pandas as pd

# Load the raw comma-separated data; the file has no header row
df = pd.read_table('data/', delimiter=',', header=None)
df.head()
0 1 2 3 4 5 6 7 8 9 ... 13 14 15 16 17 18 19 20 21 22
0 p x s n t p f c n k ... s w w p w o p k s u
1 e x s y t a f c b k ... s w w p w o p n n g
2 e b s w t l f c b n ... s w w p w o p n n m
3 p x y w t p f c n n ... s w w p w o p k s u
4 e x s g f n f w b k ... s w w p w o e n a g

5 rows × 23 columns

In [3]:
df.info()
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 8124 entries, 0 to 8123
Data columns (total 23 columns):
0     8124 non-null object
1     8124 non-null object
2     8124 non-null object
3     8124 non-null object
4     8124 non-null object
5     8124 non-null object
6     8124 non-null object
7     8124 non-null object
8     8124 non-null object
9     8124 non-null object
10    8124 non-null object
11    8124 non-null object
12    8124 non-null object
13    8124 non-null object
14    8124 non-null object
15    8124 non-null object
16    8124 non-null object
17    8124 non-null object
18    8124 non-null object
19    8124 non-null object
20    8124 non-null object
21    8124 non-null object
22    8124 non-null object
dtypes: object(23)
memory usage: 1.4+ MB

The dataset consists of 8124 training examples, each representing a single mushroom. The first column is the target variable containing the class labels, which identify whether the mushroom is poisonous (p) or edible (e). The remaining columns are 22 discrete features that describe the mushroom in some observable way; their values are encoded as single characters. For example, gill size is either broad (b) or narrow (n), and veil color can be brown (n), orange (o), white (w), or yellow (y). Each feature has several possible values, so if you'd like to peruse the details, you can find them in the data description.

Because the target variable contains discrete values, we'll need to train a classifier. But first, let's update the column labels.

In [4]:
column_labels = [
    'class', 'cap shape', 'cap surface', 'cap color', 'bruised', 'odor',
    'gill attachment', 'gill spacing', 'gill size', 'gill color', 
    'stalk shape', 'stalk root', 'stalk surface above ring',
    'stalk surface below ring', 'stalk color above ring',
    'stalk color below ring', 'veil type', 'veil color', 'ring number',
    'ring type', 'spore print color', 'population', 'habitat'
]

# Replace the default integer column labels with descriptive names
df.columns = column_labels

The data description indicates that the feature stalk root has some missing values, denoted by ?. In this analysis, we'll exclude any training example that has missing values for stalk root.

In [5]:
df = df[df['stalk root'] != '?']

Next, let's pull out the features and the target variable, and place them in their own tables.

In [6]:
# Features: every column except the class label
X = df.loc[:, df.columns != 'class']
X.head()
cap shape cap surface cap color bruised odor gill attachment gill spacing gill size gill color stalk shape ... stalk surface below ring stalk color above ring stalk color below ring veil type veil color ring number ring type spore print color population habitat
0 x s n t p f c n k e ... s w w p w o p k s u
1 x s y t a f c b k e ... s w w p w o p n n g
2 b s w t l f c b n e ... s w w p w o p n n m
3 x y w t p f c n n e ... s w w p w o p k s u
4 x s g f n f w b k t ... s w w p w o e n a g

5 rows × 22 columns

In [7]:
# Target: the class label as a single-column DataFrame
y = df['class'].to_frame()
y.head()
0 p
1 e
2 e
3 p
4 e

While we're here, let's take a look at how the two classes are distributed.

In [8]:
y['class'].value_counts()
e    3488
p    2156
Name: class, dtype: int64

There are noticeably more training examples for edible mushrooms (3488, about 62%) than for poisonous ones (2156, about 38%). We'll have to take this imbalance into account when training and evaluating the classifier.

2. Data wrangling

Encoding categorical features

Most machine learning models expect numerical features, and the last time I checked, Scikit-learn's estimators make this mandatory. However, our features are all categorical variables! This means we'll need to encode them as numbers so we can perform the math required to train a classifier.

One encoding option is to convert the distinct values of each feature into integers (LabelEncoder from Scikit-learn is handy). For example, the values for veil color are brown, orange, white, and yellow, but they could be represented by 1, 2, 3, and 4, respectively. Unfortunately, this strategy has nonsensical implications: yellow isn't four times the value of brown! If our features were ordinal categorical variables, such as T-shirt size (small, medium, large), this strategy could have worked.
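To make the problem concrete, here is a minimal sketch that runs LabelEncoder on a few hypothetical veil color values. Note that LabelEncoder numbers the distinct values in sorted order starting at 0 rather than 1, but the implied (and meaningless) ordering is the same.

```python
from sklearn.preprocessing import LabelEncoder

# Hypothetical veil color values: brown (n), orange (o), white (w), yellow (y)
veil_color = ['n', 'o', 'w', 'y', 'w', 'n']

le_veil = LabelEncoder()
codes = le_veil.fit_transform(veil_color)

# classes_[i] is the original value that was encoded as integer i
print(le_veil.classes_.tolist())
print(codes.tolist())
```

A linear model fed these codes would treat yellow (3) as "greater than" white (2), which has no physical meaning for a color.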

Instead, our features are nominal categorical variables with no intrinsic order. Therefore, we'll perform one-hot encoding, in which each feature with $z$ possible values is converted into $z$ binary features, exactly one of which is "on". We could use OneHotEncoder from Scikit-learn to execute this strategy; in versions before 0.20 this preprocessor required the categories to be already encoded as integers, although newer versions accept strings directly. Either way, get_dummies() from Pandas does everything in one go and provides readable column labels to boot!
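A small side-by-side sketch, assuming Scikit-learn 0.20 or later and a toy frame with made-up rows, shows that the two approaches produce the same binary matrix; get_dummies simply adds the "<feature>_<value>" column labels.

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import OneHotEncoder

# Toy frame with one nominal feature (hypothetical rows, not from the dataset)
toy = pd.DataFrame({'veil color': ['n', 'o', 'w', 'y', 'w']})

# Pandas: one 0/1 column per distinct value, labelled "<feature>_<value>"
dummies = pd.get_dummies(toy, dtype=int)
print(dummies.columns.tolist())

# Scikit-learn: accepts the string categories directly; the default output
# is a sparse matrix, so densify it for the comparison
ohe = OneHotEncoder()
onehot = ohe.fit_transform(toy).toarray()

# Both order the categories alphabetically, so the matrices match
print(np.array_equal(dummies.to_numpy(), onehot))
```

The Scikit-learn route has one advantage worth knowing about: a fitted encoder remembers the categories, so it can encode new data with exactly the same columns.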

In [9]:
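# One 0/1 column per (feature, value) pair; dtype=int keeps integer 0/1
# output (recent Pandas versions return booleans by default)
X_enc = pd.get_dummies(X, dtype=int)
X_enc.head()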
cap shape_b cap shape_c cap shape_f cap shape_k cap shape_s cap shape_x cap surface_f cap surface_g cap surface_s cap surface_y ... population_n population_s population_v population_y habitat_d habitat_g habitat_l habitat_m habitat_p habitat_u
0 0 0 0 0 0 1 0 0 1 0 ... 0 1 0 0 0 0 0 0 0 1
1 0 0 0 0 0 1 0 0 1 0 ... 1 0 0 0 0 1 0 0 0 0
2 1 0 0 0 0 0 0 0 1 0 ... 1 0 0 0 0 0 0 1 0 0
3 0 0 0 0 0 1 0 0 0 1 ... 0 1 0 0 0 0 0 0 0 1
4 0 0 0 0 0 1 0 0 1 0 ... 0 0 0 0 0 1 0 0 0 0

5 rows × 98 columns

By performing one-hot encoding, we also dramatically expanded the feature space from 22 to 98. This is concerning because we're increasing the likelihood of overfitting when we train a model. Let's keep an eye out for this issue when we evaluate the performance of the classifier.

Standardizing the features

Some machine learning algorithms, such as principal component analysis, are sensitive to the scale of the features and give misleading results unless the features are standardized to zero mean and unit variance; others, such as models trained with gradient descent, converge faster when they are. Tree-based models see no real benefit from feature standardization, but in general, it doesn't hurt to standardize.

In [10]:
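from sklearn.preprocessing import StandardScaler

# Rescale every one-hot column to zero mean and unit variance
scaler = StandardScaler()
X_std = scaler.fit_transform(X_enc)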
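For reference, StandardScaler computes $z = (x - \mu)/\sigma$ column by column, where $\mu$ and $\sigma$ are the column's mean and population standard deviation (ddof=0). A quick sketch on a toy 0/1 matrix with hypothetical values confirms that the manual formula agrees with the scaler:

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Toy one-hot columns (hypothetical values)
X_toy = np.array([[1, 0], [0, 1], [0, 1], [1, 0], [0, 1]], dtype=float)

# Manual standardization: subtract the column mean, divide by the population std
manual = (X_toy - X_toy.mean(axis=0)) / X_toy.std(axis=0)

scaled = StandardScaler().fit_transform(X_toy)
print(np.allclose(manual, scaled))
print(scaled.mean(axis=0), scaled.std(axis=0))
```

One practical difference: for a zero-variance column, the naive formula divides by zero, whereas StandardScaler leaves the centered column at zero.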

Encoding the target variable

Most Scikit-learn classifiers accept string class labels, but some tools, such as metrics.f1_score, assume by default that the positive class is labeled 1. Since we only have two classes, we can use LabelEncoder to map them to 0 and 1.

In [ ]:
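from sklearn.preprocessing import LabelEncoder

# Map 'e' and 'p' to integers in alphabetical order: e -> 0, p -> 1
le = LabelEncoder()
y_enc = le.fit_transform(y.values.ravel())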
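Because LabelEncoder sorts the labels, edible (e) becomes 0 and poisonous (p) becomes 1, which makes poisonous the positive class for the $F_1$ score later on. Here is a minimal check on a handful of hypothetical labels in the dataset's format:

```python
from sklearn.preprocessing import LabelEncoder

labels = ['p', 'e', 'e', 'p', 'e']  # hypothetical labels

le = LabelEncoder()
encoded = le.fit_transform(labels)

print(le.classes_.tolist())   # index i of classes_ is the label encoded as i
print(encoded.tolist())

# inverse_transform recovers the original characters from the integers
print(le.inverse_transform([0, 1]).tolist())
```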

3. Training a model with the hold-out method

I like to use the law of parsimony when solving problems; this includes the selection of machine learning models. Therefore, we'll begin with a logistic regression classifier and go from there.

For a quick and dirty analysis, we'll use the hold-out method (an 80/20 training/test split) to gauge how well the classifier performs. Because the classes are imbalanced, we'll stratify the split so the training and test sets retain the same class distribution.

In [11]:
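from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(
    X_std, y_enc, test_size=0.2, stratify=y_enc)  # stratified 80/20 split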
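To see what stratification buys us, here is a small sketch on synthetic labels with the same 62/38 ratio as our data (the arrays are made up for illustration). With stratify set, the positive-class fraction in both splits stays close to the overall fraction.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Synthetic labels: 62 negatives (edible) and 38 positives (poisonous)
y_demo = np.array([0] * 62 + [1] * 38)
X_demo = np.arange(len(y_demo)).reshape(-1, 1)

X_tr, X_te, y_tr, y_te = train_test_split(
    X_demo, y_demo, test_size=0.2, stratify=y_demo, random_state=0)

print('overall:', y_demo.mean())
print('train:  ', y_tr.mean())
print('test:   ', y_te.mean())
```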

Let's now train a logistic regression classifier on the training set using the default hyperparameters ($L_2$ regularization with $C = 1/\lambda = 1$, where $C$ is Scikit-learn's inverse regularization strength) and evaluate its performance on the test set. Because the classes are imbalanced, we need a performance metric other than classification accuracy; the $F_1$ score should do the trick.
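As a reminder, $F_1$ is the harmonic mean of precision $P$ and recall $R$ for the positive class (here poisonous, label 1): $F_1 = 2PR/(P + R)$. The sketch below, using made-up labels, checks the formula against metrics.f1_score and shows why accuracy alone can mislead: a classifier that calls every mushroom edible gets 62% accuracy on a 62/38 split but an $F_1$ of zero.

```python
import numpy as np
from sklearn import metrics

# Hypothetical labels: 1 = poisonous (positive class), 0 = edible
y_true = np.array([1, 1, 1, 0, 0, 0, 0, 0])
y_pred = np.array([1, 1, 0, 1, 0, 0, 0, 0])

tp = np.sum((y_pred == 1) & (y_true == 1))  # 2 true positives
fp = np.sum((y_pred == 1) & (y_true == 0))  # 1 false positive
fn = np.sum((y_pred == 0) & (y_true == 1))  # 1 false negative
precision = tp / (tp + fp)
recall = tp / (tp + fn)
f1_manual = 2 * precision * recall / (precision + recall)

print(f1_manual, metrics.f1_score(y_true, y_pred))

# Majority-class baseline on a 62/38 split: decent accuracy, useless F1
y_imb = np.array([0] * 62 + [1] * 38)
all_edible = np.zeros_like(y_imb)
print(metrics.accuracy_score(y_imb, all_edible))             # 0.62
print(metrics.f1_score(y_imb, all_edible, zero_division=0))  # 0.0
```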

In [12]:
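from sklearn import metrics
from sklearn.linear_model import LogisticRegression
# Fit with default hyperparameters, then score the held-out 20%
clf = LogisticRegression().fit(X_train, y_train)
y_pred = clf.predict(X_test)
metrics.f1_score(y_test, y_pred)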

It looks like the classifier has a perfect $F_1$ score on the test set, but keep in mind this is just a single training/test split; we need to confirm this performance holds for other splits. In addition, we need to tune the regularization hyperparameter. Fortunately, there's a way to tackle both at the same time without introducing additional bias.

4. Using nested cross-validation to evaluate performance

Instead of using the same data to tune the hyperparameter and evaluate model performance, we'll use nested cross-validation to avoid an optimistically biased estimate of performance.

In [13]:
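import numpy as np
from sklearn.model_selection import GridSearchCV, StratifiedShuffleSplit, cross_val_score

param_grid = [{'C': np.logspace(-3, 3, 10)}]

# Inner loop: tune C by F1 score on 10 stratified 80/20 splits
grid_search = GridSearchCV(
    LogisticRegression(), param_grid, scoring='f1', n_jobs=-1,
    cv=StratifiedShuffleSplit(n_splits=10, test_size=0.2, random_state=42))

# Outer loop: score the whole tuning procedure on 10 different splits
scores = cross_val_score(
    grid_search, X_std, y_enc, scoring='f1', n_jobs=-1,
    cv=StratifiedShuffleSplit(n_splits=10, test_size=0.2, random_state=0))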

One downside to using nested cross-validation is how computationally intensive it can be. The outer loop splits the data into 10 stratified 80/20 training/test splits and reports model performance, while the inner loop runs a grid search on each outer training split, again tuning the hyperparameter over 10 stratified 80/20 splits. In addition, each grid search tests 10 hyperparameter values. That means we've just trained 10 × 10 × 10 = 1000 models, plus one refit of the best model per outer split! Fortunately, n_jobs=-1 parallelizes the work across all CPU cores and speeds up the computation considerably.
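If you'd like to try the nested structure end to end without the mushroom data, the sketch below is one way to do it. It is an illustration rather than the analysis above: it swaps in synthetic make_classification data and smaller loops (3 outer splits, 3 inner splits, 4 values of C) so it runs in seconds. That amounts to 3 × 3 × 4 = 36 inner fits plus 3 refits.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV, StratifiedShuffleSplit, cross_val_score

# Synthetic, mildly imbalanced binary data (a stand-in for X_std / y_enc)
X_syn, y_syn = make_classification(n_samples=300, n_features=10,
                                   weights=[0.62], random_state=0)

# Inner loop: pick C by F1 on stratified shuffle splits of each outer training set
inner = GridSearchCV(
    LogisticRegression(),
    param_grid={'C': np.logspace(-2, 1, 4)},
    scoring='f1',
    cv=StratifiedShuffleSplit(n_splits=3, test_size=0.2, random_state=1))

# Outer loop: each score comes from data the inner search never saw
outer_scores = cross_val_score(
    inner, X_syn, y_syn, scoring='f1',
    cv=StratifiedShuffleSplit(n_splits=3, test_size=0.2, random_state=2))

print(outer_scores, outer_scores.mean())
```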

Let's take a look at model performance on each of the 10 test folds from the outer loop.

In [14]:
scores
array([ 1.        ,  0.99822773,  1.        ,  1.        ,  0.99822773,
        1.        ,  1.        ,  1.        ,  1.        ,  1.        ])

Given how consistent the performance is across the 10 outer test splits, we can conclude the model is indeed performing well and not overfitting. This kind of stellar performance from a linear model such as logistic regression suggests that the two classes are (nearly) linearly separable in the one-hot encoded feature space. There's really no need to train a more complex machine learning model. Lastly, let's report the mean score as the final measure of model performance.

In [15]:
scores.mean()

5. Identifying the most influential features

All that remains is to train the model on the entire dataset so it can be deployed. But first, we need to perform grid search once more to identify the optimal regularization hyperparameter.

In [16]:
grid_search.fit(X_std, y_enc)
grid_search.best_params_
{'C': 0.10000000000000001}
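The winning value comes from the logarithmically spaced grid defined earlier: np.logspace(-3, 3, 10) yields $10^{-3 + 6k/9}$ for $k = 0, \dots, 9$, so $k = 3$ gives $10^{-1}$. The trailing digits in 0.10000000000000001 are most likely an artifact of how older NumPy versions printed floats, not a different value. A quick check:

```python
import numpy as np

grid = np.logspace(-3, 3, 10)
k = np.arange(10)

# Each grid point is 10 ** (-3 + 6 * k / 9)
print(np.allclose(grid, 10.0 ** (-3 + 6 * k / 9)))
print(grid[3])
```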

Using this hyperparameter value, let's train the final model.

In [17]:
# Refit on the full dataset using the tuned C
final_clf = LogisticRegression(C=0.1).fit(X_std, y_enc)
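As an aside, GridSearchCV refits on all the data it was given by default (refit=True), so after grid_search.fit(X_std, y_enc) an equivalent model should already be available as grid_search.best_estimator_. The sketch below demonstrates this on synthetic data, with make_classification standing in for our features:

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

X_syn, y_syn = make_classification(n_samples=200, n_features=5, random_state=0)

search = GridSearchCV(LogisticRegression(), {'C': np.logspace(-3, 3, 10)},
                      scoring='f1', cv=3)
search.fit(X_syn, y_syn)

# Refitting by hand with the winning C reproduces best_estimator_
best_C = search.best_params_['C']
manual = LogisticRegression(C=best_C).fit(X_syn, y_syn)

print(best_C)
print(np.allclose(manual.coef_, search.best_estimator_.coef_))
```

Fitting final_clf explicitly, as above, keeps the chosen hyperparameter visible; using best_estimator_ avoids training the same model twice.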

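Now let's report the five features with the largest positive parameters and the five with the largest negative ones. Because the features are standardized, the magnitudes of the learned parameters are comparable across features. And since poisonous is encoded as 1, positive parameters push a prediction toward poisonous and negative ones toward edible.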

In [18]:
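# coef_ has shape (1, n_features) for a binary problem: one row of parameters
feature_ranks = pd.DataFrame(final_clf.coef_, index=['parameter value'])
feature_ranks.columns = X_enc.columns
# Largest positive parameters (pushing toward poisonous)
feature_ranks.sort_values('parameter value', axis=1, ascending=False).T.head()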
parameter value
odor_p 1.078392
odor_f 0.924890
spore print color_h 0.924890
odor_c 0.880472
spore print color_r 0.577942

In [19]:
feature_ranks.sort_values('parameter value', axis=1, ascending=True).T.head()
parameter value
odor_n -1.043962
odor_l -0.553868
odor_a -0.553868
stalk root_c -0.481090
spore print color_n -0.448844
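The two cells above sort by signed value. To rank features purely by magnitude regardless of sign, one option is to work with a Series and sort on the absolute values. The sketch below uses hypothetical feature names and coefficients to stay self-contained; with the real model, you would pass final_clf.coef_[0] and X_enc.columns. The helper name top_by_magnitude is made up for this example.

```python
import pandas as pd


def top_by_magnitude(coefs, names, n=5):
    """Return the n parameters with the largest absolute value, keeping their sign."""
    s = pd.Series(coefs, index=names)
    return s.loc[s.abs().sort_values(ascending=False).index].head(n)


# Hypothetical names and coefficients for illustration only
names = ['feat_1', 'feat_2', 'feat_3', 'feat_4', 'feat_5']
coefs = [1.2, -1.5, 0.1, -0.6, 0.3]

print(top_by_magnitude(coefs, names, n=3))
```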

Interesting! It looks like the mushroom's odor is a predominant factor in determining whether it's edible.

We can now use our model to decide whether an Agaricus or Lepiota mushroom of the kinds described in this dataset is edible. Evidently, this problem was a cakewalk for our machine learning model. Actually, I was hoping to find a dataset that would put up more of a challenge and allow me to troubleshoot the model, but it was still great for practicing with Scikit-learn. Nevertheless, this was a phenomenal example to showcase the power of machine learning: while experts concluded there is no simple rule for determining whether mushrooms of these genera are edible, a machine with no domain knowledge figured it out without breaking a sweat!

That said, this model serves to assist experts, not replace them. False negatives (poisonous mushrooms classified as edible) may be life-threatening; as such, the final word about consuming a wild mushroom should always come from an expert.

If you'd like to play around with the code, here's the GitHub repo. As always, don't hesitate to leave your comments below.

