A quick introduction to machine learning in R with caret

If you’ve been using R for a while, and you’ve been working with basic data visualization and data exploration techniques, the next logical step is to start learning some machine learning.

To help you begin learning about machine learning in R, I’m going to introduce you to an R package: the caret package. We’ll build a very simple machine learning model as a way to learn some of caret’s basic syntax and functionality.

But before diving into caret, let’s quickly discuss what machine learning is and why we use it.

What is machine learning?

Machine learning is the study of data-driven, computational methods for making inferences and predictions.

Without going into extreme depth here, let’s unpack that by looking at an example.

A simple example

Imagine that you want to understand the relationship between car weight and car fuel efficiency (i.e., miles per gallon): how is fuel efficiency affected by a car’s weight?

To answer this question, you could obtain a dataset with several different car models, and attempt to identify a relationship between weight (which we’ll call wt) and miles per gallon (which we’ll call mpg).

A good starting point would be to simply plot the data, so first, we’ll create a scatterplot using ggplot2:

# load the ggplot2 package for plotting
require(ggplot2)

# scatterplot of fuel efficiency (mpg) vs. car weight (wt)
# mtcars is a built-in R dataset with data on several car models
ggplot(data = mtcars, aes(x = wt, y = mpg)) +
  geom_point()

 
[Figure: scatterplot of mpg vs. wt for the mtcars dataset]

Just examining the data visually, it’s pretty clear that there’s some relationship. But if we want to make more precise claims about the relationship between wt and mpg, we need to specify that relationship mathematically. So, as we press forward in our analysis, we’ll make the assumption that the relationship can be described by some mathematical function, f(x). In the case of the above example, we’re assuming that “miles per gallon” can be described as a function of “car weight”.

Assuming this type of mathematical relationship, machine learning provides a set of methods for identifying that relationship. Said differently, machine learning provides a set of computational methods that accept data observations as inputs, and subsequently estimate that mathematical function, f(x); machine learning methods learn the relationship by being trained with an input dataset.

Ultimately, once we have this mathematical function (a model), we can use that model to make predictions and inferences.

How much math you really need to know

What I just wrote in the last few paragraphs about “estimating functions” and “mathematical relationships” might cause you to ask a question: “How much math do I need to know to do machine learning?”

Ok, here is some good news: to implement basic machine learning techniques, you don’t need to know much math.

To be clear, there is quite a bit of math involved in machine learning, but most of that math is taken care of for you. What I mean is that, for the most part, R libraries and functions perform the mathematical calculations for you. You just need to know which functions to use, and when to use them.

Here’s an analogy: if you were a carpenter, you wouldn’t need to build your own power tools. You wouldn’t need to build your own drill and power saw. Therefore, you wouldn’t need to understand the mathematics, physics, and electrical engineering principles that would be required to construct those tools from scratch. You could just go and buy them “off the shelf.” To be clear, you’d still need to learn how to use those tools, but you wouldn’t need a deep understanding of math and electrical engineering to operate them.

When you’re first getting started with machine learning, the situation is very similar: you can learn to use some of the tools, without knowing the deep mathematics that makes those tools work.

Having said that, the above analogy is somewhat imperfect. At some point, as you progress to more advanced topics, it will be very beneficial to know the underlying mathematics. My argument, however, is that in the beginning, you can still be productive without a deep understanding of calculus, linear algebra, etc. In any case, I’ll be writing more about “how much math you need” in another blog post.

Ok, so you don’t need to know that much math to get started, but you’re not entirely off the hook. As I noted above, you still need to know how to use the tools properly.

In some sense, this is one of the challenges of using machine learning tools in R: many of them can be difficult to use.

R has many packages for implementing various machine learning methods, but unfortunately many of these tools were designed separately, and they are not always consistent in how they work. The syntax for some of the machine learning tools is very awkward, and syntax from one tool to the next is not always the same. If you don’t know where to start, machine learning in R can become very confusing.

This is why I recommend using the caret package to do machine learning in R.

A quick introduction to caret

For starters, let’s discuss what caret is.

The caret package is a set of tools for building machine learning models in R. The name “caret” stands for Classification And REgression Training. As the name implies, the caret package gives you a toolkit for building classification models and regression models. Moreover, caret provides you with essential tools for:

– Data preparation, including: imputation, centering/scaling data, removing correlated predictors, reducing skewness
– Data splitting
– Model evaluation
– Variable selection
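
To give a quick taste of a couple of these tools, here is a minimal sketch using two of caret’s helper functions: createDataPartition() for data splitting and preProcess() for centering and scaling. (The 80/20 split and object names like train_df are just illustrative choices for this sketch.)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Sketch: data splitting and preparation
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

require(caret)

# data splitting: hold out roughly 20% of the rows as a test set
set.seed(1)
train_rows <- createDataPartition(mtcars$mpg, p = 0.8, list = FALSE)
train_df   <- mtcars[train_rows, ]
test_df    <- mtcars[-train_rows, ]

# data preparation: center and scale the numeric columns
prep_obj     <- preProcess(train_df, method = c("center", "scale"))
train_scaled <- predict(prep_obj, train_df)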

Caret simplifies machine learning in R

While caret has broad functionality, the real reason to use caret is that it’s simple and easy to use.

As noted above, one of the major problems with machine learning in R is that most of R’s different machine learning tools have different interfaces. They almost all “work” a little differently from one another: the syntax is slightly different from one modeling tool to the next; tools for different parts of the machine learning workflow don’t always “work well” together; tools for fine tuning models or performing critical functions may be awkward or difficult to work with. Said succinctly, R has many machine learning tools, but they can be extremely clumsy to work with.

Caret solves this problem. To simplify the process, caret provides tools for almost every part of the model building process, and moreover, provides a common interface to these different machine learning methods.

For example, caret provides a simple, common interface to almost every machine learning algorithm in R. When using caret, different learning methods like linear regression, neural networks, and support vector machines, all share a common syntax (the syntax is basically identical, except for a few minor changes).

Moreover, additional parts of the machine learning workflow – like cross validation and parameter tuning – are built directly into this common interface.

To say that more simply, caret provides you with an easy-to-use toolkit for building many different model types and executing critical parts of the ML workflow. This simple interface enables rapid, iterative modeling. In turn, this iterative workflow will allow you to develop good models faster, with less effort, and with less frustration.

Caret’s syntax

Now that you’ve been introduced to caret, let’s return to the example above (of mpg vs wt) and see how caret works.

Again, imagine you want to learn the relationship between mpg and wt. As noted above, in mathematical terms, this means identifying a function, f(x), that describes the relationship between wt and mpg.

Here in this example, we’re going to make an additional assumption that will simplify the process somewhat: we’re going to assume that the relationship is linear; we’ll assume that it can be described by a straight line of the form f(x) = \beta_0 + \beta_1x.

In terms of our modeling effort, this means that we’ll be using linear regression to build our machine learning model.

Without going into the details of linear regression (I’ll save that for another blog post), let’s look at how we implement linear regression with caret.

The train() function

The core of caret’s functionality is the train() function.

train() is the function that we use to “train” the model. That is, train() is the function that will “learn” the relationship between mpg and wt.

Let’s take a look at this syntactically.

Here is the syntax for a linear regression model, regressing mpg on wt.

#~~~~~~~~~~~~~~~~~~~~~~~~~~
# Build model using train()
#~~~~~~~~~~~~~~~~~~~~~~~~~~

require(caret)

model.mtcars_lm <- train(mpg ~ wt
                        ,data = mtcars
                        ,method = "lm"
                        )

 
That's it. The syntax for building a linear regression is extremely simple with caret.
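
If you want a quick look at what train() actually produced, you can print the fitted object, or inspect the underlying linear model that caret stores in the finalModel slot. (A minimal sketch; the exact output you see will vary.)

#~~~~~~~~~~~~~~~~~~~~~~~~~~
# Inspect the trained model
#~~~~~~~~~~~~~~~~~~~~~~~~~~

# printing the train object shows the method used and a resampling summary
print(model.mtcars_lm)

# the fitted lm object itself is stored in the finalModel slot
summary(model.mtcars_lm$finalModel)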

Now that we have a simple model, let's quickly extract the regression coefficients and plot the model (i.e., plot the linear function that describes the relationship between mpg and wt).

#~~~~~~~~~~~~~~~~~~~~~~~~~~
# Retrieve coefficients for
#  - slope
#  - intercept
#~~~~~~~~~~~~~~~~~~~~~~~~~~

coef.icept <- coef(model.mtcars_lm$finalModel)[1]
coef.slope <- coef(model.mtcars_lm$finalModel)[2]


#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Plot scatterplot and regression line
#  using ggplot()
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

ggplot(data = mtcars, aes(x = wt, y = mpg)) +
  geom_point() +
  geom_abline(slope = coef.slope, intercept = coef.icept, color = "red")

 
[Figure: scatterplot of mpg vs. wt with the fitted regression line overlaid in red]
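
As a quick aside, this is also where the “predictions” mentioned earlier come in: once train() has produced a model, you can feed it new data with predict(). Here is a minimal sketch (the 3,000 lb example car is made up purely for illustration; in mtcars, wt is measured in units of 1,000 lbs):

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Sketch: predict mpg for a new car
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

# a hypothetical car weighing 3,000 lbs
new_car <- data.frame(wt = 3.0)

# predicted miles per gallon for that car
predict(model.mtcars_lm, newdata = new_car)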

Now, let's look more closely at the syntax and how it works.

When training a model using train(), you only need to tell it a few things:
- The dataset you're working with
- The target variable you're trying to predict (e.g., the mpg variable)
- The input variable (e.g., the wt variable)
- The machine learning method you want to use (in this case "linear regression")

Formula notation

In caret's syntax, you identify the target variable and input variables using the "formula notation."

The basic syntax for formula notation is y ~ x, where y is your target variable, and x is your predictor.

Effectively, y ~ x tells caret "I want to predict y on the basis of a single input, x."
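
If it helps to see this concretely, a formula is just an ordinary R object that you can create and inspect on its own (a tiny illustration, not something you need to do in practice):

# the formula "predict mpg on the basis of wt"
f <- mpg ~ wt
class(f)   # "formula"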

Now, with this knowledge about caret's formula syntax, let's reexamine the above code. Because we want to predict mpg on the basis of wt, we use the formula mpg ~ wt. Again, this formula is what tells train() our target variable and our input variable. If we translate it into English, we're effectively telling train(), "build a model that predicts mpg (miles per gallon) on the basis of wt (car weight)."

The data = parameter

The train() function also has a data = parameter. This basically tells the train() function what dataset we're using to build the model. Said differently, if we're using the formula mpg ~ wt to indicate the target and predictor variables, then we're using the data = parameter to tell caret where to find those variables.

So basically, data = mtcars tells train() that the data and the relevant variables can be found in the mtcars dataset.

The method = parameter

Finally, we see the method = parameter. This parameter indicates what machine learning method we want to use to predict y. In this case, we're building a linear regression model, so we are using the argument "lm".

Keep in mind, however, that we could select a different learning method. It's beyond the scope of this blog post to discuss all of the possible learning methods, but there are many we could use. For example, if we wanted to use the k-nearest neighbor technique, we could use the "knn" argument instead. If we did that, train() would still predict mpg on the basis of wt, but would use a different statistical technique to make that prediction. This would yield a different model; a model that makes different predictions.

Again, it's beyond the scope of this post to discuss all of the different model types. However, as you learn more about machine learning, and want to try out more advanced machine learning techniques, this is how you can implement them. You simply change the learning method by changing the argument of the method = parameter.

This is a good place to reiterate one of caret's primary advantages: switching between model types is extremely easy when we use caret's train() function. Again, if you want to use linear regression to model your data, you just type in "lm" for the argument to method =; if you want to change the learning method to k-nearest neighbor, you just replace "lm" with "knn".
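
For example, here is what that swap might look like in code. Everything stays the same except the method = argument. (Treat this as a sketch: k-nearest neighbors has tuning parameters of its own, which caret will try to choose for you by default.)

#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Same model, different learning method
#~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

model.mtcars_knn <- train(mpg ~ wt
                         ,data = mtcars
                         ,method = "knn"
                         )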

Caret's syntax allows you to very easily change the learning method. In turn, this allows you to "try out" and evaluate many different learning methods rapidly and iteratively. You can just re-run your code with different values for the method parameter, and compare the results for each method.
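
And as noted earlier, other parts of the workflow, like cross validation, plug into this same interface. Here is a minimal sketch using trainControl() (the choice of 5 folds is arbitrary, purely for illustration):

#~~~~~~~~~~~~~~~~~~~~
# Sketch: adding cross validation
#~~~~~~~~~~~~~~~~~~~~

ctrl <- trainControl(method = "cv", number = 5)

model.mtcars_knn_cv <- train(mpg ~ wt
                            ,data = mtcars
                            ,method = "knn"
                            ,trControl = ctrl
                            )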

Next steps

Now that you have a high-level understanding of caret, you're ready to dive deeper into machine learning in R.

Keep in mind though, if you're new to machine learning, there's still lots more to learn. Machine learning is intricate and fascinatingly complex. Moreover, caret has a variety of additional tools for model building. We've just scratched the surface here.

At Sharp Sight Labs, we'll be creating blog posts to teach you more of the basics, including tutorials about:
- the bias-variance tradeoff
- deeper looks at regression
- classification
- data preparation
- and more.

In the meantime, send me your questions.

If you have questions about machine learning, or topics you're struggling with, sign up for the email list. Once you're on the email list, reply directly to any of the emails and send in your questions.

Joshua Ebner

Joshua Ebner is the founder, CEO, and Chief Data Scientist of Sharp Sight. Prior to founding the company, Josh worked as a Data Scientist at Apple. He has a degree in Physics from Cornell University. For more daily data science advice, follow Josh on LinkedIn.

9 thoughts on “A quick introduction to machine learning in R with caret”

  1. Thank you for this article, this is a nice easy intro into ML and the CARET package. I’m excited to start to play around with it and learn how to leverage ML.

  2. Thanks for the article. I have base knowledge in linear regression but I didn’t know caret.
    For linear regression, is caret restricted to a unique independent variable?

  3. Well, I got to install caret recently (I believe on your recommendation). I’m yet to use it, but I must say that after a couple of years of using R, I’ve never quite run into a massive package like this one! Hmm, I’m really looking forward to working with it…

  4. I am learning data analysis with R through most of your material, and just finished the last video on your course. For machine learning, what type of modeling would likely work best where I want to predict number of cases shipped based on what month/year/ # of working days in the month? I’ve been trying to build a model to do our forecasting based on historical input, however the amount of data we have is pretty thin when going by monthly, as it only goes back about 3 years. So far I’ve not had much luck finding modeling tutorials that point me in the right direction.

  5. Great introductory article! Thank you for the example code and explanations! Please follow up with different methods soon!

  6. The cool thing is the simple swap of model choice. Admittedly caret isn’t needed for this simple example at all. Basically we could have just put the regression line geometry in directly with ggplot:

    ggplot(data = mtcars, aes(wt, mpg)) + geom_point() + geom_smooth(formula = y ~ x, method = 'lm', se = FALSE)

    Nevertheless, the example was helpful for me to get a simple overview.

    • The example isn’t to say that caret is the best tool for building a simple regression line, but rather to show how caret works with a very simple example.

      My philosophy with teaching complicated things (like ML) is to start with an ultra-simple example to help people develop intuition, and then increase complexity.

