Efficient Matlab (II): Kmeans clustering algorithm

Kmeans is one of the simplest and most widely used clustering algorithms. A detailed description can be found on the Wikipedia page.

Given an initial set of k means, the algorithm proceeds by alternating between two steps until convergence:
(1) Assignment step: assign each sample point to the cluster with the closest mean.
(2) Update step: recompute each mean as the centroid of the sample points assigned to its cluster.
The algorithm is deemed to have converged when the assignments no longer change.

There is a built-in kmeans function in Matlab. Again, it is far from efficient: the built-in implementation is essentially the naive one, something like

while !converged
  for each point
    assign label
  end
  for each cluster
    compute mean
  end
end

There are at least two layers of loops, which hurt efficiency badly. Here, we will use some tricks to get rid of the inner loops by vectorization.

The assignment step finds the nearest mean for each point. Therefore, we can use the vectorized pairwise distance function we wrote in the previous post to find the nearest mean:

[~,label] = min(sqDistance(M,X),[],1);

where M is the mean matrix and X is the sample matrix. The sqDistance function is just a one-liner:

function D = sqDistance(X, Y)
D = bsxfun(@plus,dot(X,X,1)',dot(Y,Y,1))-2*(X'*Y);

For our purpose (finding the nearest mean), we do not need to recompute the dot products of the sample points every time we perform the assignment; we can precompute them. Analyzing a little further, we see that we do not need to compute them at all, since they do not affect the ranking. Therefore, we can write the assignment step efficiently as

[~,label] = max(bsxfun(@minus,m'*X,dot(m,m,1)'/2),[],1);
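
To spell out why the sample term can be dropped: for a single sample x and means m_1, \dots, m_k,

\arg\min_j \|x - m_j\|^2 = \arg\min_j \left(\|m_j\|^2 - 2\langle m_j, x\rangle\right) = \arg\max_j \left(\langle m_j, x\rangle - \tfrac{1}{2}\|m_j\|^2\right),

since \|x\|^2 is the same for every j; the one-liner above computes the right-hand side for all samples at once.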

Vectorizing the update step is a little trickier. We first build an (n x k) indicator matrix E encoding the membership of each point: E(i,j)=1 if sample i is in cluster j, and E(i,j)=0 otherwise. Dividing column j of E by the cluster size n(j) and multiplying by X then yields the mean matrix, i.e. M = X*E*diag(1/n(1),...,1/n(k)). The update step can thus be written as:

    E = sparse(1:n,label,1,n,k,n);  % transform label into indicator matrix
    m = X*(E*spdiags(1./sum(E,1)',0,k,k));    % compute m of each cluster

Note that E is a sparse matrix. Matlab automatically dispatches to a specialized routine when multiplying a sparse matrix with a dense matrix, which is far more efficient than multiplying two dense matrices, provided the sparse matrix is indeed sparse.
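
As a rough sanity check (a hypothetical snippet; exact timings depend on your machine and Matlab version), you can compare the two code paths yourself:

n = 100000; k = 50;                   % hypothetical sizes
label = randi(k,1,n);                 % fake cluster assignments
X = rand(10,n);                       % fake data, one sample per column
E = sparse(1:n,label,1,n,k,n);        % sparse indicator matrix
tic; X*E; toc                         % sparse-dense product
tic; X*full(E); toc                   % dense-dense product, noticeably slower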

Putting everything together, we have a very concise implementation (10 lines of code):

function label = litekmeans(X, k)
n = size(X,2);
last = 0;
label = ceil(k*rand(1,n));  % random initialization
while any(label ~= last)
    E = sparse(1:n,label,1,n,k,n);  % transform label into indicator matrix
    m = X*(E*spdiags(1./sum(E,1)',0,k,k));    % compute m of each cluster
    last = label;
    [~,label] = max(bsxfun(@minus,m'*X,dot(m,m,1)'/2),[],1); % assign samples to the nearest centers
end

The source code can be downloaded from here. This function is usually one or two orders of magnitude faster than the built-in Matlab kmeans function, depending on the data set. Again, the power of vectorization is tremendous.
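
A minimal usage sketch (hypothetical data; one sample per column, as assumed throughout):

X = rand(2,1000);            % 1000 two-dimensional sample points
label = litekmeans(X,5);     % partition them into 5 clusters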

The takeaways of this post might be:
(1) Vectorization!
(2) Analyze your code to remove redundant computation.
(3) Matrix multiplication between a sparse matrix and a dense matrix is handled efficiently by Matlab.


Efficient Matlab (I): pairwise distances

I want to share some tricks for making Matlab functions more efficient and robust, so I decided to write a series of blog posts. This is the first one of the series, in which I show a simple function for computing pairwise Euclidean distances between points in a high dimensional vector space.

The most important thing for efficient Matlab code is VECTORIZATION! Here I will take the distance function as an example to demonstrate how to accelerate your code by getting rid of for-loops.

Say we have two sets of points X and Y in d-dimensional space, and we want to compute the distance between every point in X and every point in Y. This function is very useful. For example, if you want to find the k nearest neighbors of a point x in a database Y, a naive method is to first compute the distances between x and all points in Y, and then sort them to find the k smallest ones.

We can naively implement this function with nested for-loops:

function D = dumDistance(X,Y)
n1 = size(X,2);
n2 = size(Y,2);
D = zeros(n1,n2);
for i = 1:n1
    for j = 1:n2
        D(i,j) = sum((X(:,i)-Y(:,j)).^2);
    end
end

Here, each column of X and Y is the vector representation of a point in d-dimensional space. The nested for-loops hurt performance badly. Actually, this is basically how the pdist function in Matlab is implemented (that is how dumb Matlab can sometimes be). We want to get rid of the for-loops and vectorize the code as much as possible. Before writing any Matlab code, it is good practice to first write your algorithm down on paper in matrix notation.
We denote by x_i a vector in X and by y_j a vector in Y. The squared distance d_{ij}^2 between x_i and y_j is

d_{ij}^2 = \|x_i - y_j\|^2 = \|x_i\|^2 + \|y_j\|^2 - 2\langle x_i, y_j\rangle,

where \langle\cdot,\cdot\rangle denotes the dot product. Since d_{ij}^2 is just an entry of the matrix D, we can write the formula in matrix form as

D = \bar{x}\,\mathbf{1}' + \mathbf{1}\,\bar{y}' - 2X'Y,

where \bar{x} (resp. \bar{y}) is the column vector of squared norms of the columns of X (resp. Y), and \mathbf{1} is a column vector of ones of the appropriate length. We can implement this directly as a one-liner:

function D = sqDistance(X, Y)
D = bsxfun(@plus,dot(X,X,1)',dot(Y,Y,1))-2*(X'*Y);

In my rough tests, the new code is more than 60 times faster than the old one for random dense matrices. For sparse matrices, the gap is even larger. That is just the power of vectorization.
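
As a quick illustration of the nearest-neighbor use case mentioned earlier (a hypothetical snippet; x is a single query point stored as a column, Y is the database):

d = 5; k = 10;                        % hypothetical dimension and neighbor count
Y = rand(d,1000);                     % database, one point per column
x = rand(d,1);                        % query point
[~,idx] = sort(sqDistance(x,Y));      % sort squared distances in ascending order
knn = Y(:,idx(1:k));                  % the k nearest neighbors of x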

If you want to compute the pairwise distances between all pairs of points within a single point set, you can simply pass X for both arguments. Matlab applies an extra optimization to matrix products of the form X'*X, where only half of the computation is done. A by-product is that the resulting matrix D is guaranteed to be symmetric, which usually cannot be guaranteed otherwise due to floating-point rounding.
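
For instance (a hypothetical snippet), the full matrix of pairwise squared distances within a single point set is simply:

X = rand(3,500);
D = sqDistance(X,X);    % 500 x 500 matrix of squared distances, symmetric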

The takeaways of this post might be:
(1) Vectorization!
(2) Matrix multiplication of the form M = X'*X is efficient (it costs about half the time of an ordinary matrix multiplication) and the result is guaranteed to be symmetric.
(3) Write your algorithm in matrix form before writing the code.


Efficient F#: tail recursive quicksort via continuation monad

Recently, I have become a fan of F#. This post is basically an exercise in implementing the quicksort algorithm in F#. The first version that comes to mind is something like this:

let rec someSort list =
    match list with
    | [] -> []
    | [x] -> [x]
    | x::xs ->
        let l, r = List.partition ((>) x) xs
        let ls = someSort l
        let rs = someSort r
        ls @ (x::rs)

Since I don't like the public interface of the function to be recursive, I rewrite it like this:

let someSort list =
    let rec loop list =
        match list with
        | [] -> []
        | [x] -> [x]
        | x::xs ->
            let l, r = List.partition ((>) x) xs
            let ls = loop l
            let rs = loop r
            ls @ (x::rs)
    loop list

This implementation has several problems. First, it is not tail recursive, which means it can potentially blow your call stack. We can transform it into a tail-recursive one using the so-called continuation-passing style (a.k.a. CPS):

let someSortCont list =
    let rec loop list cont =
        match list with
        | [] -> cont []
        | x::[] -> cont [x]
        | x::xs ->
            let l, r = List.partition ((>) x) xs
            loop l (fun ls ->
            loop r (fun rs ->
            cont (ls @ (x::rs))))
    loop list (fun x -> x)

To make the implementation easier to read (and write), we can utilize the continuation monad, which is realized in F# as a computation expression (a.k.a. workflow). We define the builder as

type ContinuationBuilder() =
    member this.Bind (m, f) = fun c -> m (fun a -> f a c)
    member this.Return x = fun k -> k x
let cont = ContinuationBuilder()

Then, the CPS sorting function can be implemented as

let someSortMonad list =
    let rec loop list =
        cont {
            match list with
            | [] -> return []
            | x::xs ->
                let l, r = List.partition ((>) x) xs
                let! ls = loop l
                let! rs = loop r
                return (ls @ (x::rs))
        }
    loop list (fun x -> x)

Everything looks good now, except that this sorting algorithm is actually not quite quicksort: the list concatenation operator @ is linear in the length of its left operand, so each level of the recursion pays an extra cost. A common optimization technique is to define the function not as returning a new list but as taking an accumulator list parameter that collects the result. We can apply this trick to the first version we wrote (someSort) to get a real quicksort:

let quickSort list =
    let rec loop list acc =
        match list with
        | [] -> acc
        | x::[] -> x::acc
        | x::xs ->
            let l, r = List.partition ((>) x) xs
            let rs = loop r acc
            loop l (x::rs)
    loop list []

Then, we transform it into a tail-recursive one using CPS:

let quickSortCont list =
    let rec loop list acc cont =
        match list with
        | [] -> cont acc
        | x::[] -> cont (x::acc)
        | x::xs ->
            let l, r = List.partition ((>) x) xs
            loop r acc (fun rs ->
            loop l (x :: rs) cont)
    loop list [] (fun x -> x)

Finally, by utilizing the continuation monad we wrote, we further simplify the algorithm as:

let quickSortMonad list =
    let rec loop list acc =
        cont {
            match list with
                | [] -> return acc
                | x::[] -> return x::acc
                | x::xs ->
                    let l, r = List.partition ((>) x) xs
                    let! rs = loop r acc
                    let! s = loop l (x::rs)
                    return s
            }
    loop list [] (fun x -> x)

I did some basic testing to see how fast each implementation runs. Here is the test script:

open System.Diagnostics

let rand = new System.Random()
let data = List.init 100000 (fun _ -> rand.NextDouble())

let test f x =
    let sw = Stopwatch()
    sw.Start()
    f x |> ignore
    sw.Stop()
    sw.ElapsedMilliseconds

printf "someSort: %dms\n" (test someSort data)
printf "someSortCont: %dms\n" (test someSortCont data)
printf "someSortMonad: %dms\n" (test someSortMonad data)
printf "quickSort: %dms\n" (test quickSort data)
printf "quickSortCont: %dms\n" (test quickSortCont data)
printf "quickSortMonad: %dms\n" (test quickSortMonad data)
printf "List.sort: %dms\n" (test List.sort data)

Here are the results on my crappy laptop (Core 2 Duo 2.2GHz, 4GB RAM):

someSort: 854ms
someSortCont: 918ms
someSortMonad: 896ms
quickSort: 791ms
quickSortCont: 796ms
quickSortMonad: 834ms
List.sort: 68ms

We can see that quickSort is a little faster than someSort. Adopting the CPS and monad tricks slows the algorithm down a bit, which is understandable. However, none of the above implementations can compete with the highly optimized built-in List.sort function :(. Still, it was a fun exercise :).

P.S. This blog post was inspired by various online materials; I cannot enumerate all of them. Many thanks to those authors.

Any suggestions for further improvement are welcome.

Hello world!

I am interested in statistical inference, especially Bayesian inference. I am also a fan of various interesting programming languages, especially functional ones. This blog is going to be a mixture of programming and statistics. I hope anyone who reads it has some fun.
