NovelEssay.com Programming Blog

Exploration of Big Data, Machine Learning, Natural Language Processing, and other fun problems.

Tesseract 4.0 C# .Net Wrapper Released!

This article is about the Tesseract 4.0 C# .Net wrapper, which is only a few days old as of April 2017.


You are probably familiar with the Tesseract 3.04 C# .Net Wrapper found here:

https://github.com/charlesw/tesseract

That is already available as a Nuget package and has many downloads.


Just about a week ago, an Alpha release of the Tesseract 4.0 C# .Net wrapper was published here:

https://github.com/tdhintz/tesseract4win64

This is an x64 only .Net assembly. 


Find the Tesseract 4.0 language packs here:

https://github.com/tesseract-ocr/tessdata

When I load the English-only language pack, it uses a reasonable 180MB of RAM. When I tried to load "all languages", it used over 8GB of RAM. 
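If you need a couple of languages but not all of them, the language argument should accept a "+"-separated list, just like the tesseract command line does (an assumption I haven't benchmarked; I'd expect RAM use to scale with the number of packs loaded):

var tessEngine = new TesseractEngine(tessdataPath, "eng+fra");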


This build is incredibly slow in debug mode: it runs 5-8X slower than in release mode, so watch out for that.


Amazingly, the .Net wrapper API works exactly the same as the Tesseract 3.04 C# .Net wrapper! (When you read about how much the engine changed, including its move to LSTM networks, this will be even more amazing.)


A very simple usage example works like this:

var tessEngine = new TesseractEngine(tessdataPath, "eng");
using (Page page = tessEngine.Process(myImage))
{
    string resultText = page.GetText();
}

Be sure to drop these two files in your \bin\debug or \bin\release folder in an x64 sub-folder, like this:

.\bin\release\x64\libtesseract400.dll
.\bin\release\x64\liblept1741.dll

When the Tesseract.dll 4.0 assembly loads, it needs to find those DLLs, or it will throw an exception in your application.
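If you want to fail fast with a clearer error, you can check for the native DLLs yourself before constructing the engine. A minimal sketch (this guard is my own addition, not part of the wrapper):

string x64Folder = Path.Combine(AppDomain.CurrentDomain.BaseDirectory, "x64");
foreach (string dll in new[] { "libtesseract400.dll", "liblept1741.dll" })
{
    // Fail with a descriptive message instead of a DLL-load exception later.
    string dllPath = Path.Combine(x64Folder, dll);
    if (!File.Exists(dllPath))
        throw new FileNotFoundException("Missing native Tesseract dependency: " + dllPath);
}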


There is a very nice Accuracy and Performance overview report of 3.04 versus 4.0 here:

https://github.com/tesseract-ocr/tesseract/wiki/4.0-Accuracy-and-Performance

I generally agree with its findings, but my own personal tests are not nearly as "improved" versus 3.04. I have a regression test that contains about 2200 pages, and I'm observing plenty of slower and less precise OCR results with Tesseract 4.0. It is certainly not all "better and faster" as of April 2017. Since this is an extremely new Alpha release, I have high hopes that it will improve over time.


Installing Python Chainer and Theano on Windows with Anaconda for GPU Processing

Let's say you want to do some GPU processing on Windows, and you want to use Python because of awesome packages like Chainer and Theano.


We'll show the setup steps for installing Python Chainer and Theano on Windows 10 in this blog article.


Some Terms:

CUDA - an API model created by Nvidia for GPU processing.

cuDNN - a neural network plugin library for CUDA

Chainer - a Python neural network framework package

Theano - a Python deep learning package


Initial Hardware and OS Requirements:

You need an Nvidia CUDA-supported video card. (I have an Nvidia GeForce GTX 750 Ti.) Check for your GPU card in the support list found here: https://developer.nvidia.com/cuda-gpus 

You need Windows 10. (Everything in this procedure is x64.)


Important: 

Versions matter a lot. I tried to do this exact same setup with Python 2.7, and I was not successful. I tried the same thing with Anaconda 2, and that didn't work. I tried it with cuDNN 5.1, and that didn't work either. So many combinations didn't work for me that I decided to write about what did work.


Procedure:

1) Install Visual Studio 2015. You must install Visual Studio before installing the CUDA Toolkit. You need the \bin\cl.exe compiler. I have the VS2015 Enterprise Edition, but the VS2015 Community Edition is free here: https://www.microsoft.com/en-us/download/details.aspx?id=48146


2) Install the CUDA Toolkit found here: https://developer.nvidia.com/cuda-downloads

That installs v8.0 to a path like this: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0


3) Download the cuDNN v5.0 here: https://developer.nvidia.com/cudnn

There is a v5.1 there, but it did not work for me. Feel free to try it, but I suggest trying v5.0 first.

cuDNN is just 3 files. You'll want to drop them into the CUDA path:

  • Drop the cudnn.h file in the folder:  C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\include\
  • Drop the cudnn64_5.dll file in the folder:  C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\bin\
  • Drop the cudnn.lib file in the folder:  C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\lib\x64\

4) Install Anaconda 3.6 for Windows x64 found here: https://repo.continuum.io/archive/Anaconda3-4.3.1-Windows-x86_64.exe

In case that link breaks, this is the page I found it at: https://www.continuum.io/downloads

You'll be doing most of your Anaconda/Python work in the Anaconda Console window. If Windows does not give you a nice link to the Anaconda Console, make a shortcut with a target that looks like this:

"%windir%\system32\cmd.exe " "/K" C:\ProgramData\Anaconda3\Scripts\activate.bat C:\ProgramData\Anaconda3

I installed Anaconda for "All Users", so it went to ProgramData. If you install for just one user, Anaconda goes under a c:\users\<your name>\ path.


5) Building Python packages requires a gcc/g++ compiler. Install MinGW-w64 for x64 from here: https://sourceforge.net/projects/mingw-w64/

WARNING: During this install, be sure to pick the x86_64 install and not the i686 install!

The default install for MinGW is at c:\Program Files\mingw-w64\x86_64-6.3.0-posix-seh-rt_v5-rev1\mingw64\bin

The space in "Program Files" will break things later, so move it to something like this instead:

C:\mingw-w64\x86_64-6.3.0-posix-seh-rt_v5-rev1\mingw64\bin


6) Environment paths! 

If you have no idea how to set environment variables in Windows, here's a link that describes how to do that: http://www.computerhope.com/issues/ch000549.htm

Add a variable called "CFlags" with this value:

  • -IC:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\include

Add a variable called "CUDA_PATH" with this value:

  • C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0

Add a variable called "LD_LIBRARY_PATH" with this value:

  • C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\lib\x64

Add a variable called "LDFLAGS" with this value:

  • -LC:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\lib\x64

Add all of the following to your PATH variable (or ensure they exist):

  • C:\mingw-w64\x86_64-6.3.0-posix-seh-rt_v5-rev1\mingw64\bin
  • C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\bin
  • C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v8.0\libnvvp
  • C:\Program Files (x86)\Microsoft Visual Studio 14.0\vc\bin
  • C:\ProgramData\Anaconda3
  • C:\ProgramData\Anaconda3\Scripts
  • C:\ProgramData\Anaconda3\Library\bin

(The Anaconda3 paths might get set automatically for you.)


7) Next, bring up your Anaconda console prompt and install some packages. Type the following lines:

pip install Pillow
pip install Pycuda
pip install Theano
pip install Chainer

If any of those fail to install, stop and figure out why. Append a -vvvv to the end of the install lines to get a very-very-very verbose dump of the install process. 

Note: If you can't get pycuda to install due to "missing stdlib.h" errors, you can get the pycuda .whl file and install that directly instead.

That is likely because one of your steps 1-6 isn't quite right, or because your GCC compiler is trying to use an old 32-bit version that you installed long ago. (That was the case for me. I had Cygwin and a 32-bit GCC compiler that caused failing pip package installs.)

I also had some build failures on Chainer with errors about "_hypot" being undefined. I fixed those by going to C:\ProgramData\Anaconda3\include\pyconfig.h and commenting out the two places in that file that do this:

//#define hypot _hypot

That appears to have fixed the issue for me, but there's probably a better solution.

8) Sanity checks and smoke tests:
First, try to import the packages from a Python command window. You can run these directly from your Anaconda console like this:
  • python -c "import theano"
  • python -c "import chainer"
  • python -c "import cupy"
If one of them fails, identify the error message and ask the Google about it. They should all import without errors.


A last smoke test is the "Hello GPU" example code from the PyCUDA tutorial. Here's a copy:
import pycuda.autoinit
import pycuda.driver as drv
import numpy
from pycuda.compiler import SourceModule

# Compile a trivial CUDA kernel that multiplies two vectors element-wise.
mod = SourceModule("""
__global__ void multiply_them(float *dest, float *a, float *b)
{
  const int i = threadIdx.x;
  dest[i] = a[i] * b[i];
}
""")
multiply_them = mod.get_function("multiply_them")

a = numpy.random.randn(400).astype(numpy.float32)
b = numpy.random.randn(400).astype(numpy.float32)
dest = numpy.zeros_like(a)

# Launch one block of 400 threads; each thread handles one element.
multiply_them(
        drv.Out(dest), drv.In(a), drv.In(b),
        block=(400,1,1), grid=(1,1))

print (dest-a*b)

I had to change the last line of that code to have parentheses around it, like this:
print (dest-a*b)

When you run that with a command like this:
python pycuda_test.py
You should get an output of all 0's (a 400-element array of zeros).



Conclusion:
If you've gotten this far, congratulations! Your Windows 10 environment should be all set up to run Python GPU processing.

My GPU has been running for days at 90%, and my CPU is free for other work. This was seriously miserable to figure out, but now it feels like my computer doubled its processing power!

Enjoy the awesome!

Machine Learning for Network Security with C# and Vowpal Wabbit


This article will discuss a solution to the KDD99 Network Security contest using C# and the Vowpal Wabbit machine learning library. This is a very quick and naive approach that yields 97%+ accuracy. Proper data preprocessing and better feature selection should result in much better predictions. Details about that Network Security contest can be found here:

http://kdd.ics.uci.edu/databases/kddcup99/kddcup99.html


Start by creating a new Visual Studio C# project and import the Vowpal Wabbit Nuget package.


Then, we build a class that describes the data records and the target labels that we have for training.

    public class DataRecord : VWRecord
    {
        public VWRecord GetVWRecord()
        {
            return (VWRecord) this;
        }

        public string label { get; set; }

        public int labelInt { get; set; }

        public bool isKnownAttackType { get; set; }
    }


Create a VWRecord class with Vowpal Wabbit annotations. This is used by VW to map features to the correct format. In this example, I set Enumerize=true as much as I can and lump the features into a single feature group. I didn't try splitting the features into different groups, but that seems like a smart and reasonable thing to explore (see the sketch after the class below).

    public class VWRecord
    {
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float duration { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public string protocol_type { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public string service { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public string flag { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float src_bytes { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float dst_bytes { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float land { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float wrong_fragment { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float urgent { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float hot { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float num_failed_logins { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float logged_in { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float num_compromised { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float root_shell { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float su_attempted { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float num_root { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float num_file_creations { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float num_shells { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float num_access_files { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float num_outbound_cmds { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float is_host_login { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float is_guest_login { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float count { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float srv_count { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float serror_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float srv_serror_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float rerror_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float srv_rerror_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float same_srv_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float diff_srv_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float srv_diff_host_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float dst_host_count { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float dst_host_srv_count { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float dst_host_same_srv_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float dst_host_diff_srv_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float dst_host_same_src_port_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float dst_host_srv_diff_host_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float dst_host_serror_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float dst_host_srv_serror_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float dst_host_rerror_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public float dst_host_srv_rerror_rate { get; set; }

    }
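For reference, splitting the features into multiple groups only requires varying the FeatureGroup letter per property. Here's a minimal sketch of what that could look like (the class name and grouping below are illustrative, not something I tested):

    public class VWRecordGrouped
    {
        // Connection basics in group 'a'.
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public string protocol_type { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)]
        public string service { get; set; }

        // Authentication signals in group 'b'.
        [Feature(FeatureGroup = 'b', Enumerize = true)]
        public float num_failed_logins { get; set; }
        [Feature(FeatureGroup = 'b', Enumerize = true)]
        public float logged_in { get; set; }

        // Traffic-rate signals in group 'c'.
        [Feature(FeatureGroup = 'c', Enumerize = true)]
        public float serror_rate { get; set; }
        [Feature(FeatureGroup = 'c', Enumerize = true)]
        public float srv_serror_rate { get; set; }

        // ...and so on for the remaining features.
    }

Grouping like this would let you cross related namespaces with -q and --cubic arguments, the same way the Grupo Bimbo solution later in this blog does.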


Next, we create a Vowpal Wabbit wrapper class. After the VWWrapper instantiation, call Init before any calls to Train or Predict.

    public class VWWrapper
    {
        VW.VowpalWabbit<VWRecord> vw = null;

        public void Init(bool train = true)
        {
            vw = new VW.VowpalWabbit<VWRecord>(new VowpalWabbitSettings
            {
                EnableStringExampleGeneration = true,
                Verbose = true,
                Arguments = string.Join(" "
                , "-f vw.model"
                , "--progress 100000"
                , "-b 27"
                )
            });

        }

        public void Train(DataRecord record)
        {
            VWRecord vwRecord = record.GetVWRecord();
            SimpleLabel label = new SimpleLabel() { Label = record.labelInt };
            vw.Learn(vwRecord, label);
        }

        public float Predict(VWRecord record)
        {
            return vw.Predict(record, VowpalWabbitPredictionType.Scalar);
        }
    }

The whole program is just a few lines. Open a data source to the training or test data set. Loop through the records and call train or predict. On the predictions, compare the prediction against the actual label and score appropriately.

Notice the cutoff score of 0.4 in the prediction function. Vowpal Wabbit will give a prediction between 0 and 1. You can tune your cutoff score to meet whatever precision and recall behaviors suit your needs. A higher cutoff score will result in more "attack" records being predicted as "normal".
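If you want to see that trade-off directly, you can cache the (prediction, label) pairs during evaluation and sweep several cutoffs over them. A minimal sketch (this helper is my own addition; results holds prediction/labelInt pairs collected in DoEvaluate):

        static void SweepCutoffs(List<Tuple<float, int>> results)
        {
            foreach (float cutoff in new[] { 0.2f, 0.3f, 0.4f, 0.5f, 0.6f })
            {
                // An "attack" call is any prediction at or above the cutoff.
                int truePositives = results.Count(r => r.Item1 >= cutoff && r.Item2 == 1);
                int falsePositives = results.Count(r => r.Item1 >= cutoff && r.Item2 == 0);
                int actualAttacks = results.Count(r => r.Item2 == 1);

                double precision = truePositives / (double)Math.Max(1, truePositives + falsePositives);
                double recall = truePositives / (double)Math.Max(1, actualAttacks);
                Console.WriteLine("cutoff=" + cutoff + " precision=" + precision.ToString("F3") + " recall=" + recall.ToString("F3"));
            }
        }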

        static string trainingSource = @"C:\kaggle\kdd99\kddcup.data.corrected";
        static string testSource = @"C:\kaggle\kdd99\corrected";

        static VWWrapper vw = new VWWrapper();

        static int trainingRecordCount = 0;
        static int testRecordCount = 0;
        static int evaluateRecordCount = 0;

        static int correctNormal = 0;
        static int correctAttack = 0;
        static int totalNormal = 0;
        static int totalAttack = 0;

        static void Main(string[] args)
        {
            Stopwatch swTotal = new Stopwatch();
            swTotal.Start();
            vw.Init();
            DoTraining();
            DoEvaluate();
            swTotal.Stop();
            Console.WriteLine("Done. ElapsedTime: " + swTotal.Elapsed);
        }


        static void DoEvaluate()
        {
            float cutoffScore = 0.4f;

            DataSource sourceEval = new DataSource(testSource);
            while (sourceEval.NextRecord())
            {
                if(sourceEval.Record.isKnownAttackType)
                {
                    float prediction = vw.Predict(sourceEval.Record.GetVWRecord());

                    if(sourceEval.Record.labelInt == 0)
                    {
                        totalNormal++;
                        if (prediction < cutoffScore) correctNormal++;
                    }
                    if(sourceEval.Record.labelInt == 1)
                    {
                        totalAttack++;
                        if (prediction >= cutoffScore) correctAttack++;
                    }
                    evaluateRecordCount++;
                }
            }

            Console.WriteLine("Evaluate Complete. evaluateRecordCount = " + evaluateRecordCount);
            Console.WriteLine("Evaluate totalNormal = " + totalNormal + " correctNormal = " + correctNormal);
            Console.WriteLine("Evaluate totalAttack = " + totalAttack + " correctAttack = " + correctAttack);
            Console.WriteLine("Evaluate DONE!");
        }

        static void DoTraining()
        {
            DataSource source = new DataSource(trainingSource);
            while (source.NextRecord())
            {
                vw.Train(source.Record);
                trainingRecordCount++;
            }
            Console.WriteLine("Train Complete. trainingRecordCount = " + trainingRecordCount);
        }


This solution was incredibly quick and easy to implement and yields a 99.6% correct prediction of normal records, and about 97% correct prediction of attack records.
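Those percentages fall straight out of the counters in DoEvaluate. For example, you could append these two lines to the end of that function (my addition, not in the original program):

            Console.WriteLine("Normal accuracy: " + (100.0 * correctNormal / totalNormal).ToString("F1") + "%");
            Console.WriteLine("Attack accuracy: " + (100.0 * correctAttack / totalAttack).ToString("F1") + "%");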


Finally, here is the data source example so you can cut & paste to try this out yourself:

    public class DataSource
    {
        // Columns (the KDD99 training and test files both have 42 columns)
        private static int COLUMN_COUNT = 42;
        private static int COLUMN_COUNT_TEST = 42;

        // Current Record Attributes
        public DataRecord Record = new DataRecord();


        private string sourceReport;

        private System.IO.StreamReader fileReader;
        private int sourceIndex;

        public DataSource(string sourceReport)
        {
            this.sourceReport = sourceReport;
            Reset();
        }

        public bool NextRecord()
        {
            bool foundRecord = false;
            while (!fileReader.EndOfStream)
            {
                try
                {
                    //Processing row: KDD99 lines end with a period, so trim it before splitting
                    string line = fileReader.ReadLine();
                    string[] fields = line.TrimEnd('.').Split(',');

                    // Expect COLUMN_COUNT columns
                    if (fields.Count() != COLUMN_COUNT && fields.Count() != COLUMN_COUNT_TEST)
                    {
                        throw new Exception(string.Format("sourceReportParser column count [{0}] != expected COLUMN_COUNT [{1}]", fields.Count(), COLUMN_COUNT));
                    }

                    Record = new DataRecord();

                    if (fields.Count() == COLUMN_COUNT)
                    {
                        Record.duration = float.Parse(fields[0]);
                        Record.protocol_type = fields[1];
                        Record.service = fields[2];
                        Record.flag = fields[3];
                        Record.src_bytes = float.Parse(fields[4]);
                        Record.dst_bytes = float.Parse(fields[5]);
                        Record.land = float.Parse(fields[6]);
                        Record.wrong_fragment = float.Parse(fields[7]);
                        Record.urgent = float.Parse(fields[8]);
                        Record.hot = float.Parse(fields[9]);
                        Record.num_failed_logins = float.Parse(fields[10]);
                        Record.logged_in = float.Parse(fields[11]);
                        Record.num_compromised = float.Parse(fields[12]);
                        Record.root_shell = float.Parse(fields[13]);
                        Record.su_attempted = float.Parse(fields[14]);
                        Record.num_root = float.Parse(fields[15]);
                        Record.num_file_creations = float.Parse(fields[16]);
                        Record.num_shells = float.Parse(fields[17]);
                        Record.num_access_files = float.Parse(fields[18]);
                        Record.num_outbound_cmds = float.Parse(fields[19]);
                        Record.is_host_login = float.Parse(fields[20]);
                        Record.is_guest_login = float.Parse(fields[21]);
                        Record.count = float.Parse(fields[22]);
                        Record.srv_count = float.Parse(fields[23]);
                        Record.serror_rate = float.Parse(fields[24]);
                        Record.srv_serror_rate = float.Parse(fields[25]);
                        Record.rerror_rate = float.Parse(fields[26]);
                        Record.srv_rerror_rate = float.Parse(fields[27]);
                        Record.same_srv_rate = float.Parse(fields[28]);
                        Record.diff_srv_rate = float.Parse(fields[29]);
                        Record.srv_diff_host_rate = float.Parse(fields[30]);
                        Record.dst_host_count = float.Parse(fields[31]);
                        Record.dst_host_srv_count = float.Parse(fields[32]);
                        Record.dst_host_same_srv_rate = float.Parse(fields[33]);
                        Record.dst_host_diff_srv_rate = float.Parse(fields[34]);
                        Record.dst_host_same_src_port_rate = float.Parse(fields[35]);
                        Record.dst_host_srv_diff_host_rate = float.Parse(fields[36]);
                        Record.dst_host_serror_rate = float.Parse(fields[37]);
                        Record.dst_host_srv_serror_rate = float.Parse(fields[38]);
                        Record.dst_host_rerror_rate = float.Parse(fields[39]);
                        Record.dst_host_srv_rerror_rate = float.Parse(fields[40]);

                        Record.label = fields[41];
                        Record.isKnownAttackType = true;

                        switch (Record.label)
                        {
                            case "buffer_overflow":
                                Record.labelInt = 1;
                                break;
                            case "ftp_write":
                                Record.labelInt = 2;
                                break;
                            case "guess_passwd":
                                Record.labelInt = 3;
                                break;
                            case "imap":
                                Record.labelInt = 4;
                                break;
                            case "ipsweep":
                                Record.labelInt = 5;
                                break;
                            case "land":
                                Record.labelInt = 6;
                                break;
                            case "loadmodule":
                                Record.labelInt = 7;
                                break;
                            case "multihop":
                                Record.labelInt = 8;
                                break;
                            case "neptune":
                                Record.labelInt = 9;
                                break;
                            case "nmap":
                                Record.labelInt = 10;
                                break;
                            case "normal":
                                Record.labelInt = 11;
                                break;
                            case "perl":
                                Record.labelInt = 12;
                                break;
                            case "phf":
                                Record.labelInt = 13;
                                break;
                            case "pod":
                                Record.labelInt = 14;
                                break;
                            case "portsweep":
                                Record.labelInt = 15;
                                break;
                            case "rootkit":
                                Record.labelInt = 16;
                                break;
                            case "satan":
                                Record.labelInt = 17;
                                break;
                            case "smurf":
                                Record.labelInt = 18;
                                break;
                            case "spy":
                                Record.labelInt = 19;
                                break;
                            case "teardrop":
                                Record.labelInt = 20;
                                break;
                            case "warezclient":
                                Record.labelInt = 21;
                                break;
                            case "warezmaster":
                                Record.labelInt = 22;
                                break;
                            case "back":
                                Record.labelInt = 23;
                                break;
                            default:
                                //Console.WriteLine("ERROR: Invalid Label Type");
                                Record.isKnownAttackType = false;
                                break;
                        }

                        // The per-attack labelInt values above get overwritten here;
                        // the switch's real job is setting isKnownAttackType.
                        // Training uses a binary label: 0 = normal, 1 = attack.
                        if(Record.label == "normal")
                        {
                            Record.labelInt = 0;
                        }
                        else
                        {
                            Record.labelInt = 1;
                        }
                    }

                    sourceIndex++;

                    // Getting here means we have a good record. Break the loop.
                    foundRecord = true;
                    break;
                }
                catch (Exception e)
                {
                    Console.WriteLine("ERROR: NextRecord failed for line: " + sourceIndex + " with exception: " + e.Message + " Stack: " + e.StackTrace);
                    sourceIndex++;
                }
            }
            return foundRecord;
        }

        public void Reset()
        {
            fileReader = new System.IO.StreamReader(sourceReport);
            // Burn column headers
            string line = fileReader.ReadLine();
            string[] fields = line.Split(',');
            if (fields.Count() != COLUMN_COUNT && fields.Count() != COLUMN_COUNT_TEST)
            {
                throw new Exception(string.Format("sourceReportParser column count [{0}] != expected COLUMN_COUNT [{1}]", fields.Count(), COLUMN_COUNT));
            }
            sourceIndex = 0;
        }
    }




Vowpal Wabbit C# solution to Kaggle competition: Grupo Bimbo Inventory Demand

Kaggle Competition: Grupo Bimbo Inventory Demand

https://www.kaggle.com/c/grupo-bimbo-inventory-demand

Vowpal Wabbit C# solution summary:

  1. Download data sets from the Kaggle competition site.
  2. Create a new Visual Studio 2015 solution and project.
  3. Install Vowpal Wabbit Nuget package to your Visual Studio project.
  4. Create a DataSource class to read the data set files.
  5. Create a VWRecord class with annotated properties.
  6. Create a VWWrapper class to manage the Vowpal Wabbit engine instance, training, and prediction API calls.
  7. Create a test program to manage training, prediction, evaluation jobs.
  8. Run test program and tinker until satisfied.


Download data sets from the Kaggle competition site.

Specifically here: https://www.kaggle.com/c/grupo-bimbo-inventory-demand/data

You don't necessarily need to download the data to understand this article, but some of the data fields may be referenced later in this article. Review the field detail descriptions to give yourself some context.


Create a new Visual Studio 2015 solution and project.

I created a C# console application in Visual Studio to run my test projects. Nothing fancy about that.


Install Vowpal Wabbit Nuget package to your Visual Studio project.

Use the Nuget package manager in Visual Studio. Search for vowpal wabbit. Install the latest version.


Create a DataSource class to read the data set files.

This class opens the source (or test) file and reads each line. Each line is split into fields and mapped to properties by column position. The number of columns tells this class how to map the fields: the training data has 11 columns, and the test data has 7 columns. Correct, this is not fancy, but it works.

    public class DataSource
    {
        // Columns
        private static int COLUMN_COUNT = 11;
        private static int COLUMN_COUNT_TEST = 7;

        // Current Record Attributes
        public DataRecord Record = new DataRecord();


        private string sourceReport;

        private System.IO.StreamReader fileReader;
        private int sourceIndex;

        public DataSource(string sourceReport)
        {
            this.sourceReport = sourceReport;
            Reset();
        }

        public bool NextRecord()
        {
            bool foundRecord = false;
            while (!fileReader.EndOfStream)
            {
                try
                {
                    //Processing row
                    string line = fileReader.ReadLine();
                    string[] fields = line.Split(',');

                    // Expect COLUMN_COUNT columns
                    if (fields.Count() != COLUMN_COUNT && fields.Count() != COLUMN_COUNT_TEST)
                    {
                        throw new Exception(string.Format("sourceReportParser column count [{0}] != expected COLUMN_COUNT [{1}]", fields.Count(), COLUMN_COUNT));
                    }

                    Record = new DataRecord();

                    if (fields.Count() == COLUMN_COUNT)
                    {
                        // Semana,Agencia_ID,Canal_ID,Ruta_SAK,Cliente_ID,Producto_ID,Venta_uni_hoy,Venta_hoy,Dev_uni_proxima,Dev_proxima,Demanda_uni_equil
                        Record.WeekId = fields[0];
                        Record.SalesDepotID = int.Parse(fields[1]);
                        Record.SalesChannelID = int.Parse(fields[2]);
                        Record.RouteID = int.Parse(fields[3]);
                        Record.Cliente_ID = int.Parse(fields[4]);
                        Record.Producto_ID = int.Parse(fields[5]);
                        Record.Venta_uni_hoy = fields[6];
                        Record.Venta_hoy = fields[7];
                        Record.Dev_uni_proxima = fields[8];
                        Record.Dev_proxima = fields[9];
                        Record.Demanda_uni_equil = fields[10];
                    }
                    else
                    {
                        //id,Semana,Agencia_ID,Canal_ID,Ruta_SAK,Cliente_ID,Producto_ID
                        Record.Id = fields[0];
                        Record.WeekId = fields[1];
                        Record.SalesDepotID = int.Parse(fields[2]);
                        Record.SalesChannelID = int.Parse(fields[3]);
                        Record.RouteID = int.Parse(fields[4]);
                        Record.Cliente_ID = int.Parse(fields[5]);
                        Record.Producto_ID = int.Parse(fields[6]);
                    }

                    sourceIndex++;

                    // Getting here means we have a good record. Break the loop.
                    foundRecord = true;
                    break;
                }
                catch (Exception e)
                {
                    Console.WriteLine("ERROR: NextRecord failed for line: " + sourceIndex + " with exception: " + e.Message + " Stack: " + e.StackTrace);
                    sourceIndex++;
                }
            }
            return foundRecord;
        }

        public void Reset()
        {
            fileReader = new System.IO.StreamReader(sourceReport);
            // Burn column headers
            string line = fileReader.ReadLine();
            string[] fields = line.Split(',');
            if (fields.Count() != COLUMN_COUNT && fields.Count() != COLUMN_COUNT_TEST)
            {
                throw new Exception(string.Format("sourceReportParser column count [{0}] != expected COLUMN_COUNT [{1}]", fields.Count(), COLUMN_COUNT));
            }
            sourceIndex = 0;
        }

    }


Create a VWRecord class with annotated properties.

This class has a property for each of the fields in the test data except the Week ID. I observed the Week ID was best for splitting the training data up rather than using it as a feature for training or prediction. Notice that the annotation on each property has a separate FeatureGroup. I tinkered with putting the properties in the same feature groups, but that yielded worse results. Notice the Enumerize=true annotations: the properties are integers, but they are ID values, hence the Enumerize = true setup.

    public class VWRecord
    {        
        [Feature(FeatureGroup = 'h', Enumerize = true)]
        public int SalesDepotID { get; set; }

        [Feature(FeatureGroup = 'i', Enumerize = true)]
        public int SalesChannelID { get; set; }

        [Feature(FeatureGroup = 'j', Enumerize = true)]
        public int RouteID { get; set; }


        [Feature(FeatureGroup = 'k', Enumerize = true)]
        public int Cliente_ID { get; set; }

        [Feature(FeatureGroup = 'l', Enumerize = true)]
        public int Producto_ID { get; set; }
    }

Create a VWWrapper class to manage the Vowpal Wabbit engine instance, training, and prediction API calls.

The VW instance is created with a call to Init. Notice the several quadratic and cubic feature spaces specified. I originally tried using --loss_function=squared but found much better results with --loss_function=quantile. Notice the commented-out lines in the Train function that allow you to view the serialized VW input strings; I found that useful for debugging and sanity checking (a small debug helper after the class shows one way to wire that up).
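For context on that loss choice: VW's quantile loss (also called pinball loss) with parameter $\tau$, where VW's default is $\tau = 0.5$, penalizes a prediction $p$ against an actual value $y$ as

$$L_\tau(p, y) = \begin{cases} \tau\,(y - p) & \text{if } y \ge p \\ (1 - \tau)\,(p - y) & \text{if } y < p \end{cases}$$

With $\tau = 0.5$ the minimizer is the conditional median rather than the mean, which suits a skewed demand distribution and pairs naturally with the median fallbacks used later.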

    public class VWWrapper
    {
        VW.VowpalWabbit<VWRecord> vw = null;

        public void Init(bool train = true)
        {
            vw = new VW.VowpalWabbit<VWRecord>(new VowpalWabbitSettings {
                EnableStringExampleGeneration = true,
                Verbose = true,
                Arguments = string.Join(" "
                , "-f vw.model"
                //, "--loss_function=squared"
                , "--loss_function=quantile"
                , "--progress 100000"
                //, "--passes 2"
                //, "--cache_file vw.cache"
                , "-b 27"

                , "-q lk"
                , "-q lj"
                , "-q li"
                , "-q lh"

                , "--cubic hkl"
                , "--cubic ikl"
                , "--cubic jkl"

                , "--cubic hjl"
                //, "--cubic hil"
                , "--cubic hij"
                )
            });
            
        }

        public void Train(VWRecord record, float label)
        {
            // Comment this in if you want to see the VW serialized input records:
            //var str = vw.Serializer.Create(vw.Native).SerializeToString(record, new SimpleLabel() { Label = label });
            //Console.WriteLine(str);
            vw.Learn(record, new SimpleLabel() { Label = label });
        }

        public void SaveModel()
        {
            vw.Native.SaveModel();
        }

        public float Predict(VWRecord record)
        {
            return vw.Predict(record, VowpalWabbitPredictionType.Scalar);
        }
    }
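If you want to see those serialized strings without touching Train, you could add a small debug helper method to the VWWrapper class. A sketch that reuses the exact serializer call from the commented-out lines (the method name and placement are my own):

        public void DumpExample(VWRecord record, float label)
        {
            // Serialize one record to VW's text input format and print it.
            var str = vw.Serializer.Create(vw.Native).SerializeToString(record, new SimpleLabel() { Label = label });
            Console.WriteLine(str);
        }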

Create a test program to manage training, prediction, evaluation jobs.

If you created a Console Application (like I did), then you can set your Main up like this. There are two primary modes: train/evaluate and train/predict. Train & evaluate lets you measure your model against labeled data that was not used during training. Vowpal Wabbit does a good job of calculating average loss on its own.


Notice the "Median" collections. When VW predicts a 0 label, we're going to ignore that and use some Median values instead.

        static string trainingSource = @"C:\kaggle\GrupoBimboInventoryDemand\train.csv";
        static string testSource = @"C:\kaggle\GrupoBimboInventoryDemand\test.csv";
        static string outputPredictions = @"C:\kaggle\GrupoBimboInventoryDemand\myPredictions.csv";

        static VWWrapper vw = new VWWrapper();

        static int trainingRecordCount = 0;
        static int testRecordCount = 0;
        static int evaluateRecordCount = 0;

        static Dictionary<int, float> productDemand = new Dictionary<int, float>();
        static Dictionary<int, int> productRecords = new Dictionary<int, int>();
        static Dictionary<int, List<float>> productDemandInstances = new Dictionary<int, List<float>>();
        static List<float> demandInstances = new List<float>();
        static Int64 totalDemand = 0;

        static Dictionary<int, float> productMedians = new Dictionary<int, float>();
        static float allProductsMedian = 0;

        static void Main(string[] args)
        {
            Stopwatch swTotal = new Stopwatch();
            swTotal.Start();

            vw.Init();

            bool testOnly = true;
            if(testOnly)
            {
                DoTraining(true);
                DoEvaluate();
            }
            else
            {
                DoTraining();
                DoPredictions();
            }

            // Save model here if you want to load and reuse it.
            // Current test app doesn't ever load/reuse it.
            //vw.SaveModel();

            swTotal.Stop();
            Console.WriteLine("Done. ElapsedTime: " + swTotal.Elapsed);
        }
        static float Median(float[] xs)
        {
            var ys = xs.OrderBy(x => x).ToList();
            double mid = (ys.Count - 1) / 2.0;
            return (ys[(int)(mid)] + ys[(int)(mid + 0.5)]) / 2;
        }


The DoTraining function loops through the DataSource records and maps the fields into a VWRecord. The VWRecord is sent to the Train function along with the known label. Each record's label is saved into product demand and instance collections, so we can calculate median demand values later. If skipEightNine is set to true, the training data for weeks 8 & 9 will not be used for training; week 8 & 9 data is later used for model evaluation.

        static void DoTraining(bool skipEightNine = false)
        {
            DataSource source = new DataSource(trainingSource);
            while (source.NextRecord())
            {
                if (skipEightNine && (source.Record.WeekId == "8" || source.Record.WeekId == "9")) continue;

                VWRecord vwRecord = new VWRecord();
                vwRecord.Cliente_ID = source.Record.Cliente_ID;
                vwRecord.Producto_ID = source.Record.Producto_ID;
                vwRecord.RouteID = source.Record.RouteID;
                vwRecord.SalesChannelID = source.Record.SalesChannelID;
                vwRecord.SalesDepotID = source.Record.SalesDepotID;

                float label = float.Parse(source.Record.Demanda_uni_equil);

                demandInstances.Add(label);
                totalDemand += (Int64)label;
                if (productDemand.ContainsKey(source.Record.Producto_ID))
                {
                    productDemand[source.Record.Producto_ID] += label;
                    productRecords[source.Record.Producto_ID]++;
                    productDemandInstances[source.Record.Producto_ID].Add(label);
                }
                else
                {
                    productDemand.Add(source.Record.Producto_ID, label);
                    productRecords.Add(source.Record.Producto_ID, 1);
                    productDemandInstances.Add(source.Record.Producto_ID, new List<float>() { label });
                }

                trainingRecordCount++;

                vw.Train(vwRecord, label);
            }

            // Calculate medians
            foreach(var product in productDemandInstances.Keys)
            {
                productMedians.Add(product, Median(productDemandInstances[product].ToArray()));
            }
            allProductsMedian = Median(demandInstances.ToArray());
            // end median calculations

            Console.WriteLine("Train Complete. trainingRecordCount = " + trainingRecordCount);
        }


The DoEvaluate function loops through the DataSource records and only processes the week 8 & 9 records. Each of those records is mapped to a VWRecord. Then a Vowpal Wabbit prediction is made on the data. If the prediction is 0, we use the median values as predictions instead.


The difference between the predicted value and the actual value is calculated and reported as the RMSLE value. RMSLE is the Root Mean Squared Logarithmic Error, which is used as the competition's evaluation measurement:

https://www.kaggle.com/wiki/RootMeanSquaredLogarithmicError
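Written out, with $p_i$ the predicted demand, $a_i$ the actual demand, and $n$ the record count (this matches the logDiff computation in the code below):

$$\mathrm{RMSLE} = \sqrt{\frac{1}{n}\sum_{i=1}^{n}\bigl(\log(p_i + 1) - \log(a_i + 1)\bigr)^2}$$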

        static void DoEvaluate()
        {
            double logDiffSquaredSum = 0;

            DataSource sourceEval = new DataSource(trainingSource);
            while (sourceEval.NextRecord())
            {
                if (sourceEval.Record.WeekId != "8" && sourceEval.Record.WeekId != "9") continue;

                VWRecord vwRecord = new VWRecord();

                vwRecord.Cliente_ID = sourceEval.Record.Cliente_ID;
                vwRecord.Producto_ID = sourceEval.Record.Producto_ID;
                vwRecord.RouteID = sourceEval.Record.RouteID;
                vwRecord.SalesChannelID = sourceEval.Record.SalesChannelID;
                vwRecord.SalesDepotID = sourceEval.Record.SalesDepotID;

                float actualLabel = float.Parse(sourceEval.Record.Demanda_uni_equil);

                float prediction = vw.Predict(vwRecord);

                if (prediction == 0)
                {
                    if (productDemand.ContainsKey(sourceEval.Record.Producto_ID))
                    {
                        prediction = productMedians[sourceEval.Record.Producto_ID];
                    }
                    else
                    {
                        prediction = allProductsMedian;
                    }
                }

                double logDiff = Math.Log(prediction + 1) - Math.Log(actualLabel + 1);
                logDiffSquaredSum += Math.Pow(logDiff, 2);


                evaluateRecordCount++;
            }
             
            Console.WriteLine("Evaluate Complete. evaluateRecordCount = " + evaluateRecordCount);
            Console.WriteLine("Evaluate RMSLE = " + Math.Pow(logDiffSquaredSum / evaluateRecordCount, 0.5));
            Console.WriteLine("Evaluate DONE!");
        }


After you have a model you like from the train & evaluate process, you will want to run train & predict. The DoPredictions function sets up the test data set as a DataSource, loops through the records, and calls Predict for each record. Whenever the prediction is 0, the median values are used as the prediction instead. The final output is written to a file with two columns: the test record ID and the predicted value.

        static void DoPredictions()
        {
            System.IO.File.WriteAllText(outputPredictions, "id,Demanda_uni_equil\n");

            StringBuilder predictionSB = new StringBuilder();

            DataSource sourceTest = new DataSource(testSource);
            while (sourceTest.NextRecord())
            {
                VWRecord vwRecord = new VWRecord();

                vwRecord.Cliente_ID = sourceTest.Record.Cliente_ID;
                vwRecord.Producto_ID = sourceTest.Record.Producto_ID;
                vwRecord.RouteID = sourceTest.Record.RouteID;
                vwRecord.SalesChannelID = sourceTest.Record.SalesChannelID;
                vwRecord.SalesDepotID = sourceTest.Record.SalesDepotID;

                testRecordCount++;
                if (testRecordCount % 100000 == 0)
                {
                    Console.WriteLine("sourceTest recordCount = " + testRecordCount);
                }

                float prediction = vw.Predict(vwRecord);

                if (prediction == 0)
                {
                    if (productDemand.ContainsKey(sourceTest.Record.Producto_ID))
                    {
                        prediction = productMedians[sourceTest.Record.Producto_ID];
                    }
                    else
                    {
                        prediction = allProductsMedian;
                    }
                }

                string outputLine = sourceTest.Record.Id + "," + prediction + "\n";
                predictionSB.Append(outputLine);
            }

            System.IO.File.AppendAllText(outputPredictions, predictionSB.ToString());

            Console.WriteLine("Predict Complete. recordCount =" + testRecordCount);
        }


Run test program and tinker until satisfied.

If you run this program exactly as it is, you'll get an RMSLE score of about 0.548 for train & evaluate. This yields a Kaggle competition score of 0.532. You can tinker with the features, feature spaces, cubic/quadratic features, and other Vowpal Wabbit engine configurations to try to improve your scores. I'm sure there are many ways to improve this particular solution, but hopefully it provides you with a good framework to use.