While machine learning refers to any model used to find patterns in data, usually to make a prediction or classification, there is a wide range of methods for finding these patterns. As described in the previous post in this series, machine learning architectures can be designed to suit many different types of tasks – supervised, unsupervised, and reinforcement learning. Each of these approaches suits a different objective for the analyst – whether that is to predict a value or class, find natural groups in some data, or find a workable strategy for a complicated task.
In addition to the kind of problem they aim to solve, these models can also vary in their “complexity” – effectively, how many parameters they utilize, and how exactly these parameters interact with one another. As these methods become more complex, it becomes trickier for human observers to understand exactly how they make their decisions. In a previous blog, we discussed the need for greater investment into research on AI bias and explainability. The ability to interpret how a model makes its decisions can be a crucial tool for determining whether a system suffers from bias, and if so, whether technical or implementation methods can be employed to reduce or eliminate it.
In order to understand this distinction, assume that an analyst is attempting to predict rent stress using a variety of factors available in American Community Survey (ACS) data. This time, however, the analyst has data broken down by census tracts – a much smaller geographical breakdown than ward-level data – and has more features available for each observation: percentage without a vehicle, in a >20 unit complex, with broadband access, and with access to a computer.
Using the 177 census tracts in DC with renters represented in the ACS, the analyst could use the percentage with broadband access as a predictor of rent stress. Such a model would be easy to interpret: take 85% (the ‘intercept’ term) and subtract 0.57 (the ‘slope’ term) times the percentage in a tract with broadband internet (Figure 1). The slope and intercept are referred to as ‘parameters’ of the model – ultimately, these are the values that the model estimates to get the most accurate results possible.
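The single-feature model above can be sketched in a few lines. The intercept (85) and slope (-0.57) are the values described in the post; in practice, an analyst would estimate them from the tract data rather than hard-code them:

```python
def predict_rent_stress(pct_broadband: float) -> float:
    """Predict % rent-stressed in a tract from % with broadband access.

    The intercept (85) and slope (-0.57) are the fitted values
    described in the post, hard-coded here for illustration.
    """
    return 85 - 0.57 * pct_broadband

# A tract where 90% of households have broadband:
print(predict_rent_stress(90))  # 85 - 0.57 * 90, roughly 33.7
```

Reading a prediction is exactly the arithmetic in the text: start at the intercept and subtract the slope's contribution.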
An analyst is unlikely to use such a simple model to explain a complex phenomenon such as rent stress in a population, as no single factor can explain rent stress. As a result, they would likely incorporate as many potentially useful features as possible. Despite this increase in the number of features used for a given prediction, the interpretation of how the model makes decisions remains the same – take the intercept, multiply each input feature by its slope, and take the sum of the parts (Figure 2). Now, there are five total parameters for the model to estimate, but the interpretation is the same – each slope parameter represents that feature’s influence on the variable of interest.
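With four features the prediction is still just an intercept plus a weighted sum, one slope per feature. A sketch of that arithmetic (the slope values below are placeholders for illustration, not the coefficients actually fitted in the post):

```python
import numpy as np

# Hypothetical parameters: one intercept plus one slope per feature.
intercept = 85.0
slopes = np.array([0.10, 0.05, -0.57, -0.08])  # placeholder values

def predict(features: np.ndarray) -> float:
    """features: [% without vehicle, % in >20-unit complex,
    % with broadband, % with computer access] for one tract."""
    return intercept + slopes @ features

tract = np.array([12.0, 30.0, 90.0, 95.0])
print(predict(tract))
```

Each slope can still be read directly as that feature's influence on the prediction, which is what makes this model easy to audit.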
Figure 1: A simple model to predict rent stress in census tracts using percentage with broadband internet (red line) and a more complex model to predict rent stress in census tracts using four features (percentage without a vehicle, in a >20 unit complex, with access to broadband internet, and with access to a computer; green line). The horizontal axis (% with broadband internet) of this plot is the most influential feature in this model, so the predicted values are similar to the previous single-feature model; the unevenness of this model can be attributed to the influence of the other three features.
The model depicted here (linear regression) is one of the most straightforward predictive models to interpret. More generally, many conventional machine learning approaches can be readily understood, even by observers without a mathematical background. While the design of a linear regression model can affect how interpretable it is, the prediction made will always be a linear combination of the input features – allowing outside observers to easily understand why a prediction was made. Many other simple ML tools operate similarly (Figure 2).
Figure 2. Above: Example design of a linear regression model. Each feature is multiplied by a slope parameter (represented by the lines) and added together.
Below: Interpretation of various machine learning tools: (a) linear regression, (b) decision trees, and (c) k-nearest neighbors.
Deep learning is employed to solve the same kinds of problems as ‘conventional’ machine learning tools, but due to a major difference in these models’ architecture, it can be challenging to understand how their decisions are made. Any machine learning tool operates by estimating ‘parameters’ that lead to the best result possible. In the simple linear model described above, there is one parameter for each input feature, plus the intercept.
Deep learning models are no different, but they incorporate far more features than almost any conventional tool to make predictions. This is because these models break down prediction and classification as a nested hierarchy of features. Instead of using a combination of some input features to make a prediction directly, deep learning generates new features from these input features as part of the training process.
Instead of directly estimating the outcome of interest using the input features, the model incorporates ‘neurons’ or ‘nodes’ in between the features and the outcome in a process that was originally modeled after the neural networks that make up the human brain. These nodes, organized in ‘layers,’ influence one another and eventually arrive at the prediction of interest.
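The node-and-layer structure can be sketched as a forward pass: each layer of nodes is a weighted combination of the previous layer's outputs, passed through a nonlinearity. The weights below are random placeholders standing in for a trained model, and the 4-input, 3-node layout is an assumption for illustration:

```python
import numpy as np

rng = np.random.default_rng(0)

def relu(x):
    """A common nonlinearity: zero out negative values."""
    return np.maximum(0.0, x)

# One hidden layer of 3 nodes between 4 input features and 1 output.
W1, b1 = rng.normal(size=(3, 4)), np.zeros(3)   # input -> hidden
W2, b2 = rng.normal(size=(1, 3)), np.zeros(1)   # hidden -> output

def forward(features: np.ndarray) -> np.ndarray:
    hidden = relu(W1 @ features + b1)   # hidden-node activations
    return W2 @ hidden + b2             # final prediction

x = np.array([12.0, 30.0, 90.0, 95.0])  # one tract's features
print(forward(x))
```

Note that an entry of W1 no longer describes a feature's effect on rent stress directly – only its effect on one intermediate node.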
Figure 3: A simple neural network. Even when using a very ‘shallow’ neural network, there is much greater complexity inherent to a neural network compared to a linear model – and the parameters estimated no longer have a direct interpretation with respect to the predicted percentage of rent stress.
This process of feature generation is what allows deep learning to perform so well in complex environments. Deep learning is particularly useful in image recognition tasks – given images of faces, for instance, a deep learning model may begin by simply finding ‘edges’ between light and dark spots before using these edges to detect more complicated features (eyes, nose, mouth, etc.) before finally combining these to detect or identify a face. A less complex model might require an analyst to label these facial features themselves, requiring domain-specific knowledge and labor input. As a result, the parameters estimated in the neural network lose the useful interpretation that they carried in the linear model case – instead of representing the effect of a feature on the outcome of interest, they represent how much one node influences the next layer of nodes. Additionally, in practice, the sub-features engineered by the model rarely have well-defined human interpretations.
In relatively simple neural networks, like the one presented in Figure 3, it may be possible to glean some meaning from these individual nodes; however, in practice, these neural networks can contain hundreds of layers with hundreds of nodes per layer for complex tasks in industry. As one extreme case, GPT-3, OpenAI’s flagship natural language processing tool built using a deep neural network, uses 175 billion parameters. For context, the simple linear model contains two parameters, while the more complex model contains five – as the number of parameters grows, computation costs increase dramatically, while the parameters themselves lose the direct meaning they carry in simpler models.
In response to this inverse relationship between model complexity and interpretability, designers of these tools employ various methods to better explain and interpret predictions, encompassing the growing field of ‘explainable AI’ (XAI). In addition to a litany of model-specific approaches which go beyond the scope of this post, some model-agnostic approaches include:
- Model simplification. In many cases, simpler models (including the linear regression approach outlined above) offer comparable levels of accuracy to deep learning approaches while being highly interpretable. The trade-off here is relatively straightforward – deep learning models tend to be more accurate than simpler models due to their ability to model more complex relationships. However, how much more accurate they are depends on how complex the data generation process is in the first place – if an analyst is modeling something with well-defined simple relationships, a deep learning approach may not be worth the additional complexity.
- Secondary modeling of inputs and outputs. While it may seem counter-intuitive, it is possible to train a neural network to make predictions and then train a second, more interpretable model to use the neural network’s output as a target variable. This approach preserves the predictive power of deep learning while allowing analysts to leverage simpler models to audit the decision-making process – however, this approach tends to struggle to explain outliers, which may be the most important to disambiguate depending on the application.
- Feature dropout/sensitivity analysis. Various subsets of the original features can be fit using a machine learning tool to determine which features are most effective at increasing accuracy. Similarly, data can be modified to test which features suffer most from the introduction of noise. The downside is typically computation – for models that are hugely computationally intensive to fit, repeatedly training can be a costly endeavor that may or may not lead to significant insights into the model’s decision-making process.
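The dropout/sensitivity idea can be sketched model-agnostically by shuffling one feature's values at a time and measuring how much error grows. The "model" and data below are synthetic stand-ins, not a real trained network:

```python
import numpy as np

rng = np.random.default_rng(1)

# Synthetic data: feature 0 matters most, feature 2 not at all.
X = rng.normal(size=(200, 3))
y = 2.0 * X[:, 0] - 0.5 * X[:, 1] + rng.normal(scale=0.1, size=200)

def model(X):
    """Stand-in 'black box' model (here, the true relationship)."""
    return 2.0 * X[:, 0] - 0.5 * X[:, 1]

def permutation_importance(model, X, y):
    """Error increase when each feature's column is shuffled."""
    baseline = np.mean((model(X) - y) ** 2)
    scores = []
    for j in range(X.shape[1]):
        Xp = X.copy()
        Xp[:, j] = rng.permutation(Xp[:, j])  # destroy feature j's signal
        scores.append(np.mean((model(Xp) - y) ** 2) - baseline)
    return scores

print(permutation_importance(model, X, y))
# Feature 0 should show the largest error increase; feature 2 near zero.
```

Because the model is only queried, never retrained, this variant is cheap; the retraining variant described above gives a fuller picture at much higher cost.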
Ultimately, interpretability is paramount when attempting to understand whether a machine learning tool is discriminating against protected groups, robust to adversarial attacks from bad actors, and well-suited for unseen data. The ability to explain the decisions made by these models represents a major challenge for researchers and businesses alike – but unlocking this ability would allow AI innovation to flourish while being kept in line with inclusive values.