Use Extension method to write cleaner code
This blog post will show you step by step to refactoring some code to be more readable (at least what I think). Patrik Löwnedahl gave me some of the ideas when we where talking about making code much cleaner.
The following is an simple application that will have a list of movies (Normal and Transfer). The task of the application is to calculate the total sum of each movie and also display the price of each movie.
class Program { enum MovieType { Normal, Transfer } static void Main(string[] args) { var movies = GetMovies(); int totalPriceOfNormalMovie = 0; int totalPriceOfTransferMovie = 0; foreach (var movie in movies) { if (movie == MovieType.Normal) { totalPriceOfNormalMovie += 2; Console.WriteLine("$2"); } else if (movie == MovieType.Transfer) { totalPriceOfTransferMovie += 3; Console.WriteLine("$3"); } } } private static IEnumerable<MovieType> GetMovies() { return new List<MovieType>() { MovieType.Normal, MovieType.Transfer, MovieType.Normal }; } }
In the code above I’m using an enum, a good way to add types (isn’t it ;)). I also use one foreach loop to calculate the price, the loop has a condition statement to check what kind of movie is added to the list of movies. I want to reuse the foreach only to increase performance and let it do two things (isn’t that smart of me?! ;)).
First of all I can admit, I’m not a big fan of enum. Enum often results in ugly condition statements and can be hard to maintain (if a new type is added we need to check all the code in our app to see if we use the enum somewhere else). I don’t often care about pre-optimizations when it comes to write code (of course I have performance in mind). I rather prefer to use two foreach to let them do one things instead of two. So based on what I don’t like and Martin Fowler’s Refactoring catalog, I’m going to refactoring this code to what I will call a more elegant and cleaner code.
First of all I’m going to use Split Loop to make sure the foreach will do one thing not two, it will results in two foreach (Don’t care about performance here, if the results will results in bad performance, you can refactoring later, but computers are so fast to day, so iterating through a list is not often so time consuming.)
Note: The foreach actually do four things, will come to is later.
var movies = GetMovies(); int totalPriceOfNormalMovie = 0; int totalPriceOfTransferMovie = 0; foreach (var movie in movies) { if (movie == MovieType.Normal) { totalPriceOfNormalMovie += 2; Console.WriteLine("$2"); } } foreach (var movie in movies) { if (movie == MovieType.Transfer) { totalPriceOfTransferMovie += 3; Console.WriteLine("$3"); } }
To remove the condition statement we can use the Where extension method added to the IEnumerable<T> and is located in the System.Linq namespace:
foreach (var movie in movies.Where( m => m == MovieType.Normal)) { totalPriceOfNormalMovie += 2; Console.WriteLine("$2"); } foreach (var movie in movies.Where( m => m == MovieType.Transfer)) { totalPriceOfTransferMovie += 3; Console.WriteLine("$3"); }
The above code will still do two things, calculate the total price, and display the price of the movie. I will not take care of it at the moment, instead I will focus on the enum and try to remove them. One way to remove enum is by using the Replace Conditional with Polymorphism. So I will create two classes, one base class called Movie, and one called MovieTransfer. The Movie class will have a property called Price, the Movie will now hold the price:
public class Movie { public virtual int Price { get { return 2; } } } public class MovieTransfer : Movie { public override int Price { get { return 3; } } }
The following code has no enum and will use the new Movie classes instead:
class Program { static void Main(string[] args) { var movies = GetMovies(); int totalPriceOfNormalMovie = 0; int totalPriceOfTransferMovie = 0; foreach (var movie in movies.Where( m => !(m is MovieTransfer))) { totalPriceOfNormalMovie += movie.Price; Console.WriteLine(movie.Price); } foreach (var movie in movies.Where( m => m is MovieTransfer)) { totalPriceOfTransferMovie += movie.Price; Console.WriteLine(movie.Price); } } private static IEnumerable<Movie> GetMovies() { return new List<Movie>() { new Movie(), new MovieTransfer(), new Movie() }; } }
If you take a look at the foreach now, you can see it still actually do two things, calculate the price and display the price. We can do some more refactoring here by using the Sum extension method to calculate the total price of the movies:
static void Main(string[] args) { var movies = GetMovies(); int totalPriceOfNormalMovie = movies.Where(m => !(m is MovieTransfer)) .Sum(m => m.Price); int totalPriceOfTransferMovie = movies.Where(m => m is MovieTransfer) .Sum(m => m.Price); foreach (var movie in movies.Where( m => !(m is MovieTransfer))) Console.WriteLine(movie.Price); foreach (var movie in movies.Where( m => m is MovieTransfer)) Console.WriteLine(movie.Price); }
Now when the Movie object will hold the price, there is no need to use two separate foreach to display the price of the movies in the list, so we can use only one instead:
foreach (var movie in movies) Console.WriteLine(movie.Price);
If we want to increase the Maintainability index we can use the Extract Method to move the Sum of the prices into two separate methods. The name of the method will explain what we are doing:
static void Main(string[] args) { var movies = GetMovies(); int totalPriceOfMovie = TotalPriceOfMovie(movies); int totalPriceOfTransferMovie = TotalPriceOfMovieTransfer(movies); foreach (var movie in movies) Console.WriteLine(movie.Price); } private static int TotalPriceOfMovieTransfer(IEnumerable<Movie> movies) { return movies.Where(m => m is MovieTransfer) .Sum(m => m.Price); } private static int TotalPriceOfMovie(IEnumerable<Movie> movies) { return movies.Where(m => !(m is MovieTransfer)) .Sum(m => m.Price); }
Now to the last thing, I love the ForEach method of the List<T>, but the IEnumerable<T> doesn’t have it, so I created my own ForEach extension, here is the code of the ForEach extension method:
public static class LoopExtensions { public static void ForEach<T>(this IEnumerable<T> values, Action<T> action) { Contract.Requires(values != null); Contract.Requires(action != null); foreach (var v in values) action(v); } }
I will now replace the foreach by using this ForEach method:
static void Main(string[] args) { var movies = GetMovies(); int totalPriceOfMovie = TotalPriceOfMovie(movies); int totalPriceOfTransferMovie = TotalPriceOfMovieTransfer(movies); movies.ForEach(m => Console.WriteLine(m.Price)); }
The ForEach on the movies will now display the price of the movie, but maybe we want to display the name of the movie etc, so we can use Extract Method by moving the lamdba expression into a method instead, and let the method explains what we are displaying:
movies.ForEach(DisplayMovieInfo); private static void DisplayMovieInfo(Movie movie) { Console.WriteLine(movie.Price); }
Now the refactoring is done, or?
EDIT: It's not done. Niclas Nilsson wrote a good comment that I actually do two things in the TotalPrice... methods, a filtering and a Sum. So both a command and a query. so it will not follow the Command and Query Principle. To make the code more seperated by concerns:
The TotalPriceXXX method are removed and replaced with on called TotalPriceOfMovie. Two properties are added, one for returning the normal movies and one for returning the transfer movies. Here is the result of the new refactoring:
static void Main(string[] args) { int totalPriceOfMovie = TotalPriceOfMovie(NormalMovies); int totalPriceOfTransferMovie = TotalPriceOfMovie(MovieTransfers); foreach (var movie in movies) Console.WriteLine(movie.Price); } private static int TotalPriceOfMovie(IEnumerable<Movie> movies) { return movies.Sum(m => m.Price); } private static IEnumerable<Movie> NormalMovies { get { return AllMovies.Where(m => !(m is MovieTransfer)); }
}
private static IEnumerable<Movie> MovieTransfers
{ get { return AllMovies.Where(m => m is MovieTransfer); }
}
private static IEnumerable<Movie> AllMovies
{ get { return GetAllMovies(); }
}
private static IEnumerable<Movie> GetMovies() { return new List<Movie>() { new Movie(), new MovieTransfer(), new Movie() }; }
Note: The GetMovies() method in this code is just a fake method to return fake data, so it could instead be a call to a Repository etc.
I think the new code is much cleaner than the first one, and I love the ForEach extension on the IEnumerable<T>, I can use it for different kind of things, for example:
movies.Where(m => m is Movie)
.ForEach(DoSomething);
By using the Where and ForEach extension method, some if statements can be removed and will make the code much cleaner. But the beauty is in the eye of the beholder. What would you have done different, what do you think will make the first example in the blog post look much cleaner than my results, comments are welcome!
If you want to know when I will publish a new blog post, you can follow me on twitter: http://www.twitter.com/fredrikn