In these notes, we discuss some ideas on how to make a code run faster. As an example, we will demonstrate how to modify the code shown near the end of this page to make it run much faster.

Why is a Code Slow?

Usually we want to optimize a code because it runs very slowly. Sometimes a code running fairly fast becomes slow when it is used in a simulation in which it is called thousands of times. This is precisely the case for the code that performs the randomization tests near the end of this note. For convenience, we copy the code below.

computeR2 <- function(y,x) {
  summary(lm(y~x))$r.squared
}
survey <- read.csv("http://courses.atlas.illinois.edu/spring2016/STAT/STAT200/RProgramming/data/Stat100_2014spring_survey02.csv")
R20 <- computeR2(survey$gayMarriage, survey$ethnicity) # original R^2
set.seed(63741891)
t0 <- system.time( R2 <- replicate(5000, computeR2(sample(survey$gayMarriage), 
                                                   survey$ethnicity)) )
t0
   user  system elapsed 
   7.48    0.00    7.48 

It takes about 7.5 seconds to compute R2 on my office computer, and the randomization experiment is repeated only 5000 times. The lm() function runs fairly fast when it is used once on a data set that is not very large. When it is called thousands of times, on the other hand, the total time adds up. For the randomization tests, we only need to compute R2. It is a huge waste of time to use the lm() function because it also calculates many other quantities we don’t need.
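To see how much more than R2 is computed, we can list what summary(lm()) returns. The following is a minimal sketch on made-up data; x_demo and y_demo are hypothetical stand-ins, not the survey variables.

```r
# Made-up data standing in for the survey (hypothetical values)
set.seed(1)
x_demo <- factor(sample(c("A", "B", "C"), 200, replace = TRUE))
y_demo <- rnorm(200)

fit_summary <- summary(lm(y_demo ~ x_demo))

# Every component that summary() builds on each call
names(fit_summary)

# The only component the randomization test actually uses
fit_summary$r.squared
```

Coefficients, residuals, standard errors and the F statistic are all recomputed on each of the 5000 calls, only to be thrown away.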

Optimized Code – Version 1

It is very easy to make the code run faster. We just don’t use the lm() function but compute R2 using the formula R2=SSB/SST, where
\[SSB = \sum_{i=1}^n (\hat{y}_i - \bar{y})^2\] \[SST = \sum_{i=1}^n (y_i - \bar{y})^2\] Note that in the randomization tests, we scramble y and so it doesn’t change SST and \(\bar{y}\). This means that SST and \(\bar{y}\) need only be computed once, not 5000 times. The only calculation that needs to be repeated is \(\hat{y}_i\), which is the mean of the group to which the ith observation belongs. When we scramble y, the group means change. The following is a new function that replaces computeR2():

computeR2b <- function(y,x,SST,bar_y) {
  means <- tapply(y,x,mean) # group means
  hat_y <- means[x]
  SSB <- sum((hat_y - bar_y)^2)
  SSB/SST
}
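The two claims above (scrambling y leaves SST and the mean unchanged, and SSB/SST reproduces the R2 from lm()) can be checked on a small made-up data set. This is only a sketch: x_demo, y_demo and the other *_demo names are hypothetical and are not part of the survey data.

```r
# Same function as in the notes
computeR2b <- function(y, x, SST, bar_y) {
  means <- tapply(y, x, mean)   # group means, in the order of the factor levels
  hat_y <- means[x]             # a factor index picks the mean of each observation's group
  SSB <- sum((hat_y - bar_y)^2)
  SSB / SST
}

# Made-up data: three groups with different means
set.seed(2)
x_demo <- factor(sample(c("A", "B", "C"), 300, replace = TRUE))
y_demo <- rnorm(300) + as.integer(x_demo)

bar_y_demo <- mean(y_demo)
SST_demo <- sum((y_demo - bar_y_demo)^2)

# Scramble y once, as one step of the randomization test would
y_perm <- sample(y_demo)

# The mean and SST are the same before and after scrambling
perm_invariant <- isTRUE(all.equal(mean(y_perm), bar_y_demo)) &&
  isTRUE(all.equal(sum((y_perm - mean(y_perm))^2), SST_demo))

# R2 from the formula and from lm() on the scrambled data
r2_formula <- computeR2b(y_perm, x_demo, SST_demo, bar_y_demo)
r2_lm <- summary(lm(y_perm ~ x_demo))$r.squared
c(formula = r2_formula, lm = r2_lm)
```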

We can now do the randomization test again and see if it runs faster:

y <- survey$gayMarriage
x <- factor(survey$ethnicity) # make sure x is a factor (read.csv() returns character columns by default in R >= 4.0)
n <- length(y)
bar_y <- mean(y)
SST <- (n-1)*var(y)
set.seed(63741891)
t1 <- system.time( R2b <- replicate(5000, computeR2b(sample(y),x,SST,bar_y)) )
t1
   user  system elapsed 
   1.72    0.00    1.72 

This version of the code is about 4 times faster (1.72 seconds versus 7.48 seconds). Before we can trust the code, we need to check that we still get the same result with the new code. We can test it by comparing R2 and R2b:

max(abs(R2-R2b))
[1] 4.475587e-16

This confirms that the two codes produce the same result within machine round-off error. Can we do better?
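When comparing results like this, exact equality is usually too strict for floating-point numbers. The sketch below uses two short made-up vectors a and b (hypothetical, not taken from the survey) to show three ways of comparing: the maximum absolute difference used above, identical(), and the tolerance-based all.equal().

```r
# Two vectors that are equal mathematically but not bit for bit
a <- c(0.1 + 0.2, 1/3)
b <- c(0.3, 1 - 2/3)

max(abs(a - b))          # tiny, at the level of round-off error
identical(a, b)          # exact comparison: too strict here
isTRUE(all.equal(a, b))  # comparison within a small tolerance
```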

Optimized Code – Version 2

There are several things we can try to make the code run faster. The changes, all of which appear in the code below, are: replace the divisions by SST and by the group sizes with multiplications by reciprocals that are computed only once; create the vector hat_y once, outside the function; and replace the tapply(y,x,mean) call with a loop over precomputed group indices.

Most of the changes are straightforward. The trickiest part is the optimization of the tapply(y,x,mean) calculation. The idea is to find the indices in the data frame corresponding to each group. These indices are fixed since x does not change, so we can save them in a list. The following code does that.

g <- length(levels(x)) # Number of groups
ind_g <- list() # Initialize a list
for (i in 1:g) {
  ind_g[[i]] <- which(as.integer(x)==i)
}
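As a small illustration of the loop above, here is a sketch on a made-up factor x_demo (not the survey data). It also shows split(), a base R function that builds the same kind of list in one call; the names x_demo, ind_loop and ind_split are ours.

```r
# Made-up factor with levels "a", "b", "c"
x_demo <- factor(c("b", "a", "c", "a", "b", "a"))
g_demo <- length(levels(x_demo))

# Loop version, as in the notes
ind_loop <- list()
for (i in 1:g_demo) {
  ind_loop[[i]] <- which(as.integer(x_demo) == i)
}
ind_loop[[1]]  # positions of level "a": 2 4 6

# One-call alternative: split the positions 1..n by group
ind_split <- unname(split(seq_along(x_demo), x_demo))
```

Both lists hold one integer vector per level, in the order of levels(x_demo).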

In our case, the number of groups is g=5 and ind_g is a list of length 5. The expression which(as.integer(x)==i) returns the indices for group i, with i=1, 2, 3, 4, and 5. Next, we want to count the number of observations in each group. This can be done easily using the table() function:

table(x)
x
      Asian       Black    Hispanic Mixed_Other       White 
        184          84          99          43         448 

What we really want is to use these numbers to calculate the group means, which are the group sums divided by the group sizes. Since multiplying is cheaper than dividing, we pass the reciprocals 1/table(x) to the function. We can store these numbers in a vector f1oNg:

f1oNg <- unname(1/table(x))

where we used the unname function to strip off the names in the vector. It is time to try our second version of the code:

computeR2c <- function(y,bar_y,f1oSST,g,f1oNg,ind_g,hat_y) {
  for (i in 1:g) {
    hat_y[ind_g[[i]]] <- f1oNg[i]*sum(y[ind_g[[i]]])
  }
  f1oSST*sum((hat_y-bar_y)^2)
}

f1oSST <- 1/SST
hat_y <- rep(NA,n) # initialize hat_y
set.seed(63741891)
t2 <- system.time( R2c <- replicate(5000, 
                                    computeR2c(sample(y),bar_y,
                                               f1oSST,g,f1oNg,ind_g,hat_y)) )
t2
   user  system elapsed 
   0.38    0.00    0.38 

That’s about 4.5 times faster than the first version and about 20 times faster than the unoptimized version! We need to make sure we didn’t make any mistakes, so we compare the result with that produced by the old code:

max(abs(R2-R2c))
[1] 4.475587e-16

Fantastic!

Gathering all the pieces, we have the complete second version of the code:

rm(list=ls()) # clear workspace

# Function that calculates R^2
computeR2c <- function(y,bar_y,f1oSST,g,f1oNg,ind_g,hat_y) {
  for (i in 1:g) {
    hat_y[ind_g[[i]]] <- f1oNg[i]*sum(y[ind_g[[i]]])
  }
  f1oSST*sum((hat_y-bar_y)^2)
}

# load data
survey <- read.csv("http://courses.atlas.illinois.edu/spring2016/STAT/STAT200/RProgramming/data/Stat100_2014spring_survey02.csv")
y <- survey$gayMarriage
x <- factor(survey$ethnicity) # make sure x is a factor (read.csv() returns character columns by default in R >= 4.0)

# Quantities that need to be computed only once
n <- length(y) # number of observations
bar_y <- mean(y)
SST <- (n-1)*var(y)
f1oSST <- 1/SST
g <- length(levels(x)) # Number of groups
ind_g <- list() # Initialize a list
for (i in 1:g) {
  ind_g[[i]] <- which(as.integer(x)==i)
}
f1oNg <- unname(1/table(x))
hat_y <- rep(NA, n) # Initialize hat_y
R20 <- computeR2c(y, bar_y,f1oSST,g,f1oNg,ind_g,hat_y) # original R^2

# Perform the randomization test
set.seed(63741891)
R2c <- replicate(5000, computeR2c(sample(y),bar_y,f1oSST,g,f1oNg,ind_g,hat_y))
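If the survey file cannot be downloaded, the same steps can be exercised on made-up data. The sketch below is not part of the original analysis: all *_demo names and the simulated values are hypothetical, and only 200 replications are used to keep it quick. It compares computeR2c() with the lm() approach using the same seed.

```r
# Function from the notes, unchanged
computeR2c <- function(y, bar_y, f1oSST, g, f1oNg, ind_g, hat_y) {
  for (i in 1:g) {
    hat_y[ind_g[[i]]] <- f1oNg[i] * sum(y[ind_g[[i]]])
  }
  f1oSST * sum((hat_y - bar_y)^2)
}

# Made-up data: four groups, 500 observations
set.seed(3)
x_demo <- factor(sample(c("A", "B", "C", "D"), 500, replace = TRUE))
y_demo <- rnorm(500) + 0.3 * as.integer(x_demo)

# Quantities that need to be computed only once
n_demo <- length(y_demo)
bar_y_demo <- mean(y_demo)
f1oSST_demo <- 1 / ((n_demo - 1) * var(y_demo))
g_demo <- length(levels(x_demo))
ind_g_demo <- list()
for (i in 1:g_demo) {
  ind_g_demo[[i]] <- which(as.integer(x_demo) == i)
}
f1oNg_demo <- unname(1 / table(x_demo))
hat_y_demo <- rep(NA, n_demo)

# Same seed for both versions, so both see the same scrambled y each time
set.seed(4)
R2_fast <- replicate(200, computeR2c(sample(y_demo), bar_y_demo, f1oSST_demo,
                                     g_demo, f1oNg_demo, ind_g_demo, hat_y_demo))
set.seed(4)
R2_slow <- replicate(200, summary(lm(sample(y_demo) ~ x_demo))$r.squared)

max_diff <- max(abs(R2_fast - R2_slow))
max_diff  # should be at the level of round-off error
```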

Sorting a Data Frame and Cache Misses

The sort() function can be used to sort a vector. Sometimes we want to sort a data frame using the variable in a particular column. We can do this using the order() function. Suppose we want to sort the data frame survey by the variable ‘ethnicity’. The command is

survey2 <- survey[order(survey$ethnicity), ]

You can verify that in this new data frame survey2, ‘ethnicity’ is ‘Asian’ in the first 184 rows, ‘Black’ in the next 84 rows, ‘Hispanic’ in the next 99 rows, ‘Mixed_Other’ in the next 43 rows, and ‘White’ in the last 448 rows.
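One way to do this kind of check is with rle(), which reports the lengths of runs of identical values. The sketch below uses a tiny made-up data frame survey_demo (hypothetical values, not the real survey) so that it can be run on its own.

```r
# Tiny made-up data frame with the same two column names as the survey
survey_demo <- data.frame(
  ethnicity   = c("White", "Asian", "Black", "Asian", "White", "White"),
  gayMarriage = c(5, 3, 4, 2, 1, 5)
)

# Sort by 'ethnicity', as in the notes
survey_demo2 <- survey_demo[order(survey_demo$ethnicity), ]

# In a sorted column each label forms a single run
runs <- rle(as.character(survey_demo2$ethnicity))
runs$values   # "Asian" "Black" "White"
runs$lengths  # 2 1 3
```

Applied to survey2$ethnicity, the run lengths should match the counts given by table(x).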

We can create the x and y vectors associated with this new data frame and then create the ind_g list the same way as above:

y <- survey2$gayMarriage
x <- factor(survey2$ethnicity) # make sure x is a factor (read.csv() returns character columns by default in R >= 4.0)
g <- length(levels(x)) # Number of groups
ind_g <- list() # Initialize a list
for (i in 1:g) {
  ind_g[[i]] <- which(as.integer(x)==i)
}

Since ‘ethnicity’ is sorted in survey2, the first element in the list ind_g is just 1, 2, 3, …, 184. We can verify that:

ind_g[[1]]
  [1]   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
 [18]  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34
 [35]  35  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51
 [52]  52  53  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68
 [69]  69  70  71  72  73  74  75  76  77  78  79  80  81  82  83  84  85
 [86]  86  87  88  89  90  91  92  93  94  95  96  97  98  99 100 101 102
[103] 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119
[120] 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136
[137] 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153
[154] 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
[171] 171 172 173 174 175 176 177 178 179 180 181 182 183 184

The second element in the ind_g list should be the 84 numbers following 184, i.e. 185, 186, …, 268. Let’s verify that:

ind_g[[2]]
 [1] 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201
[18] 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218
[35] 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235
[52] 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252
[69] 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268

You may wonder what this sorting process has to do with code optimization. The answer is that it won’t help much with the small data set that we have. But for a big data set with billions of observations, it may make the code faster by reducing cache misses.

You have probably heard of random access memory, or RAM. Roughly speaking, the computer works with two kinds of memory: the regular RAM and the cache memory, a small and fast memory that sits close to the processor. Data and programs stored in the cache memory can be accessed more quickly than those stored in the regular RAM. Cache memory is more expensive and much smaller in capacity than the regular RAM. In the command hat_y[ind_g[[i]]] <- f1oNg[i]*sum(y[ind_g[[i]]]) in the above code, the computer has to access the memory locations corresponding to y[ind_g[[i]]], process the data, and put the results in the memory locations corresponding to hat_y[ind_g[[i]]]. If not all of these memory locations are in the cache memory, the operations slow down because the computer has to fetch data from outside the cache. This is called a cache miss (see note 1 at the end of this page).

When a vector is created, nearby indices are usually stored in nearby memory locations. Therefore, if the indices in ind_g[[i]] are consecutive integers, the memory locations of hat_y[ind_g[[i]]] and y[ind_g[[i]]] are likely to be close together, and they are also likely to be in the cache memory during the calculations. This could reduce cache misses. Note that we can only hope to reduce cache misses; we probably cannot eliminate them entirely. For example, scrambling the y variable most likely leads to cache misses for a large data set. However, the cache miss issue is only significant for a large data set. Our small data set most likely fits in the cache memory entirely, so cache misses are not an issue.
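The effect of the access pattern can be explored with a rough experiment. The following sketch is not from the survey analysis: y_big, idx_sorted and idx_scrambled are hypothetical names, and the vector length is an arbitrary choice. It sums the same vector once through consecutive indices and once through scrambled indices.

```r
set.seed(5)
n_big <- 5e6                       # arbitrary size, much larger than the survey
y_big <- runif(n_big)

idx_sorted    <- seq_len(n_big)    # consecutive memory locations
idx_scrambled <- sample(n_big)     # the same indices in random order

t_sorted    <- system.time(sum_sorted    <- sum(y_big[idx_sorted]))
t_scrambled <- system.time(sum_scrambled <- sum(y_big[idx_scrambled]))

# Elapsed times in seconds; the values depend on the machine
c(sorted = t_sorted[["elapsed"]], scrambled = t_scrambled[["elapsed"]])

# Both sums use the same numbers, so they agree up to round-off error
all.equal(sum_sorted, sum_scrambled)
```

No timings are quoted here because they vary with the hardware and the R version. R may also treat a plain sequence such as seq_len(n_big) specially, so part of any difference may come from that rather than from the cache.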

Here is our final version of the optimized code. There is only one line added to sort the data frame.

rm(list=ls()) # clear workspace

# Function that calculates R^2
computeR2c <- function(y,bar_y,f1oSST,g,f1oNg,ind_g,hat_y) {
  for (i in 1:g) {
    hat_y[ind_g[[i]]] <- f1oNg[i]*sum(y[ind_g[[i]]])
  }
  f1oSST*sum((hat_y-bar_y)^2)
}

# load data
survey <- read.csv("http://courses.atlas.illinois.edu/spring2016/STAT/STAT200/RProgramming/data/Stat100_2014spring_survey02.csv")
survey <- survey[order(survey$ethnicity),] # Sort the data frame by 'ethnicity'
y <- survey$gayMarriage
x <- factor(survey$ethnicity) # make sure x is a factor (read.csv() returns character columns by default in R >= 4.0)

# Quantities that need to be computed only once
n <- length(y) # number of observations
bar_y <- mean(y)
SST <- (n-1)*var(y)
f1oSST <- 1/SST
g <- length(levels(x)) # Number of groups
ind_g <- list() # Initialize a list
for (i in 1:g) {
  ind_g[[i]] <- which(as.integer(x)==i)
}
f1oNg <- unname(1/table(x))
hat_y <- rep(NA, n) # Initialize hat_y
R20 <- computeR2c(y, bar_y,f1oSST,g,f1oNg,ind_g,hat_y) # original R^2

# Perform the randomization test
set.seed(63741891)
R2d <- replicate(5000, computeR2c(sample(y),bar_y,f1oSST,g,f1oNg,ind_g,hat_y))

When system.time() is used to time the calculation of R2d, we get the same timing result as before. So the sorting does not speed up the code, as expected, since our data set is small. It should be noted that since the survey data frame is sorted before doing the randomization test, sample(y) no longer produces the same scrambled responses as in the previous versions, even though we use the same seed number. As a result, R2d is no longer the same as R2. If you are worried about making mistakes in the coding, you can check it by running the unoptimized code on the sorted data frame and comparing the result with that of the new code:

computeR2 <- function(y,x) {
  summary(lm(y~x))$r.squared
}
set.seed(63741891)
R2e <- replicate(5000, computeR2(sample(survey$gayMarriage),survey$ethnicity))
max(abs(R2e-R2d))
[1] 7.580742e-16

This confirms that the two codes produce the same result to within machine round-off error.

Finally, we should point out that it is a good idea to check to see if there are already existing R packages for the calculations you want to do before writing your own code, especially if the code you want to write may involve complicated calculations. Most likely, codes in the existing packages written by experts are already battle-tested and well-optimized.





  1. The situation is even worse if the data set is so large that it cannot be loaded entirely into RAM. The computer then has to fetch data frequently from a hard drive (or from a remote location over the internet) during the calculations. I have encountered this situation once or twice. It was an unpleasant experience.