Friday, April 20, 2018

Why does matrix multiplication work the way it does?

Matrices can be thought of as a collection of values organised into rows and columns. Among the many operations defined on matrices, matrix multiplication is one of the most commonly used. But one may wonder what motivated its algorithm, where you multiply element-wise row X of the first matrix with column Y of the second matrix, and then sum the results to obtain the element at position (X, Y) of the result matrix. In this post, I try to address that question.
To understand the motivation behind this operation, let's first look at vectors, which are a simpler form of matrix, i.e. a matrix having just a single row or a single column. An example of a vector is a point in 3D physical space, which is basically a row or a column with 3 elements, each representing the distance of the point from the origin along one of the three dimensions. Another example is a pixel in an RGB image, which is again a vector of 3 elements, each representing the intensity of the respective colour. Now, let's look at a slightly more complex example, based on physical space itself (like our first example on vectors), that provides insight into the utility of a matrix and also explains the idea behind the matrix multiplication algorithm. Consider a point (x, y) in 2D space, i.e. a 2D vector, which we would like to map to another point in 3D space (i.e. a 3D vector) using the translation formula (2x + 3y, 3x + 2y, 4x + y). To obtain the corresponding 3D vector, one can simply substitute the values of x and y from the input 2D vector into the translation formula. That's simple, but the important part is that this translation can also be accomplished with a matrix: we multiply the matrix with the input 2D vector to obtain the output 3D vector, as shown in the image below:

Looking at the first equation in the image, you can guess how we arrived at the 3x2 matrix there: it is simply derived from the coefficients of the translation formula mentioned above. By multiplying this matrix with any 2D vector, we obtain the corresponding 3D vector given by the translation formula! So what matrix multiplication does is apply a translation operation to an input vector, where the elements of the matrix represent the amount each input vector element contributes to each output vector element. For example, in this case, the first element of the output vector (i.e. 8) is computed as two times the first element of the input vector (i.e. 1) plus three times the second element (i.e. 2), that is, 2*1 + 3*2 = 8. This is precisely the basic step of the matrix multiplication algorithm.
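The step above is easy to verify with numpy. The sketch below builds the 3x2 coefficient matrix from the translation formula and multiplies it with the input 2D vector (1, 2) from the example:

```python
import numpy as np

# Coefficients of the translation formula (2x + 3y, 3x + 2y, 4x + y),
# written as a 3x2 matrix: one row per output element.
A = np.array([[2, 3],
              [3, 2],
              [4, 1]])

v = np.array([1, 2])  # the input 2D vector (x = 1, y = 2)

result = np.dot(A, v)
print(result)  # [8 7 6]
```

The first output element is 2*1 + 3*2 = 8, matching the element-wise multiply-and-sum step described above.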
One may ask: why not just use the translation formula directly? Why represent it in matrix form? The answer (I think) is that the matrix form and the matrix multiplication operation simplify the translation, especially when there is a huge input data-set. For example, if you want to apply the same transformation to 1000 such 2D vectors, then simply stacking them column-wise into a matrix (which will then have dimensions 2x1000) and multiplying it with the same 3x2 transformation matrix gives all 1000 3D vectors at once! This is roughly what the second equation in the figure above shows. This representation immensely helps when doing such computations on a computer, which can process huge chunks of data in one go. In fact, processors have supported SIMD instructions for the past decade or so, and scientific libraries like MATLAB, or numpy in Python, use these instructions to perform matrix operations really fast! I have written a small Python script to show the speed-up you may gain by using matrix multiplication instead of many scalar multiplications. It uses the same example as in the image above and compares the two approaches over 1 million 2D vectors, performing the matrix multiplication with numpy's numpy.dot(<mat1>, <mat2>) method. The speed-up I see on my Windows 10 laptop (with an Intel Core i5) is approx. 100x!
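The original script is not included in the post, so the following is only a reconstruction of that benchmark under the same assumptions: the 3x2 matrix from the example above, 1 million 2D vectors stacked column-wise, and numpy.dot for the batched version. Exact timings and the speed-up factor will vary by machine.

```python
import time
import numpy as np

n = 1_000_000
A = np.array([[2, 3],
              [3, 2],
              [4, 1]])
# Stack n random 2D vectors column-wise into a 2 x n matrix.
points = np.random.rand(2, n)

# Approach 1: apply the translation formula one vector at a time.
start = time.time()
out_loop = np.empty((3, n))
for i in range(n):
    x, y = points[0, i], points[1, i]
    out_loop[:, i] = (2 * x + 3 * y, 3 * x + 2 * y, 4 * x + y)
loop_time = time.time() - start

# Approach 2: a single matrix multiplication over all vectors at once.
start = time.time()
out_mat = np.dot(A, points)
mat_time = time.time() - start

assert np.allclose(out_loop, out_mat)  # both approaches agree
print(f"loop: {loop_time:.3f}s, matmul: {mat_time:.3f}s, "
      f"speed-up: {loop_time / mat_time:.0f}x")
```

Both approaches produce the same 3 x n result matrix; the difference is only in how much of the work is handed to numpy's vectorised routines in one call.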
