Machine Learning for Network Security with C# and Vowpal Wabbit

This article discusses a solution to the KDD99 Network Security contest using C# and the Vowpal Wabbit machine learning library. It is a very quick and naive approach that still yields 97%+ accuracy; proper data preprocessing and better feature selection should produce much better predictions. Details about the contest can be found here:

http://kdd.ics.uci.edu/databases/kddcup99/kddcup99.html

Start by creating a new Visual Studio C# project and installing the Vowpal Wabbit NuGet package.

Then, we build a class that describes the data records and the target labels that we have for training.

    /// <summary>
    /// One parsed line of the data set: the VW feature properties inherited from
    /// <see cref="VWRecord"/> plus the label information used for training and scoring.
    /// </summary>
    public class DataRecord : VWRecord
    {
        // Raw label text taken from the last column of the line.
        public string label { get; set; }

        // Numeric target handed to VW (DataSource sets 0 for "normal", 1 otherwise).
        public int labelInt { get; set; }

        // True when label is one of the labels DataSource recognises.
        public bool isKnownAttackType { get; set; }

        /// <summary>Returns this record typed as its VW feature base class.</summary>
        public VWRecord GetVWRecord()
        {
            VWRecord features = this;
            return features;
        }
    }

Create a VWRecord class with Vowpal Wabbit attributes. VW uses these attributes to map each property to a feature in its input format. In this example, I set Enumerize=true on every feature and lump all the features into a single feature group ('a'). With Enumerize, each distinct value is treated as its own categorical feature, even for the numeric columns. I didn’t try splitting the features into different groups, but that seems like a smart and reasonable thing to explore.

    /// <summary>
    /// Feature vector handed to Vowpal Wabbit. Every property is tagged with a
    /// [Feature] attribute that places it in feature group 'a' with Enumerize = true.
    /// Properties are declared in the same order in which DataSource.NextRecord reads
    /// them from fields[0] through fields[40].
    /// </summary>
    public class VWRecord
    {
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float duration { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public string protocol_type { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public string service { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public string flag { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float src_bytes { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float dst_bytes { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float land { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float wrong_fragment { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float urgent { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float hot { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float num_failed_logins { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float logged_in { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float num_compromised { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float root_shell { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float su_attempted { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float num_root { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float num_file_creations { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float num_shells { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float num_access_files { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float num_outbound_cmds { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float is_host_login { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float is_guest_login { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float count { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float srv_count { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float serror_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float srv_serror_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float rerror_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float srv_rerror_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float same_srv_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float diff_srv_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float srv_diff_host_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float dst_host_count { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float dst_host_srv_count { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float dst_host_same_srv_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float dst_host_diff_srv_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float dst_host_same_src_port_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float dst_host_srv_diff_host_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float dst_host_serror_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float dst_host_srv_serror_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float dst_host_rerror_rate { get; set; }
        [Feature(FeatureGroup = 'a', Enumerize = true)] public float dst_host_srv_rerror_rate { get; set; }
    }

Next, we create a Vowpal Wabbit wrapper class. After instantiating VWWrapper, call Init before any calls to Train or Predict.

    /// <summary>
    /// Thin wrapper around a typed Vowpal Wabbit learner over <see cref="VWRecord"/>.
    /// Call <see cref="Init"/> before <see cref="Train"/> or <see cref="Predict"/>, and
    /// call <see cref="Dispose"/> when finished so the native VW instance is released.
    /// </summary>
    public class VWWrapper : System.IDisposable
    {
        VW.VowpalWabbit<VWRecord> vw = null;

        /// <summary>
        /// Creates the underlying VW instance with the arguments
        /// "-f vw.model" (final model output file), "--progress 100000" (progress
        /// report interval) and "-b 27" (27-bit feature hash table).
        /// Calling Init again disposes the previously created instance first.
        /// </summary>
        /// <param name="train">Not used by the current implementation.</param>
        public void Init(bool train = true)
        {
            // Re-initialising must not leak the native VW instance created earlier.
            Dispose();

            vw = new VW.VowpalWabbit<VWRecord>(new VowpalWabbitSettings
            {
                EnableStringExampleGeneration = true,
                Verbose = true,
                Arguments = string.Join(" "
                , "-f vw.model"
                , "--progress 100000"
                , "-b 27"
                )
            });
        }

        /// <summary>
        /// Performs one online learning step on the record's features, using
        /// <c>record.labelInt</c> as a simple (regression) label.
        /// </summary>
        /// <exception cref="System.InvalidOperationException">Init has not been called.</exception>
        public void Train(DataRecord record)
        {
            EnsureInitialized();
            VWRecord vwRecord = record.GetVWRecord();
            SimpleLabel label = new SimpleLabel() { Label = record.labelInt };
            vw.Learn(vwRecord, label);
        }

        /// <summary>Returns VW's scalar prediction for the record.</summary>
        /// <exception cref="System.InvalidOperationException">Init has not been called.</exception>
        public float Predict(VWRecord record)
        {
            EnsureInitialized();
            return vw.Predict(record, VowpalWabbitPredictionType.Scalar);
        }

        /// <summary>
        /// Releases the native VW instance, if one exists. Safe to call more than once.
        /// NOTE(review): VW is expected to write the "-f vw.model" file when the
        /// instance is finished/disposed — confirm against the VW.NET documentation.
        /// </summary>
        public void Dispose()
        {
            if (vw != null)
            {
                vw.Dispose();
                vw = null;
            }
        }

        // Fails with a descriptive error instead of a NullReferenceException when Init was skipped.
        private void EnsureInitialized()
        {
            if (vw == null)
            {
                throw new System.InvalidOperationException("VWWrapper.Init must be called before Train or Predict.");
            }
        }
    }

The whole program is just a few lines. Open a data source on the training or test data set, loop through the records, and call Train or Predict on each one. For predictions, compare the prediction against the actual label and update the score accordingly.

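Notice the cutoff score of 0.4 in the evaluation function (DoEvaluate). Because the training labels are 0 (normal) and 1 (attack), Vowpal Wabbit gives a prediction between 0 and 1. A record is classified as an attack when its prediction is at or above the cutoff. You can tune the cutoff to get whatever precision and recall behavior suits your needs. A higher cutoff results in more “attack” records being predicted as “normal”, trading fewer false alarms for more missed attacks.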

        // Input files: the training set and the labelled test set.
        private static string trainingSource = @"C:\kaggle\kdd99\kddcup.data.corrected";
        private static string testSource = @"C:\kaggle\kdd99\corrected";

        // Model wrapper shared by training and evaluation.
        private static VWWrapper vw = new VWWrapper();

        // Record counters (static int fields start at 0).
        // NOTE(review): testRecordCount is not updated anywhere in the code shown.
        private static int trainingRecordCount;
        private static int testRecordCount;
        private static int evaluateRecordCount;

        // Evaluation tallies, split by true class.
        private static int correctNormal;
        private static int correctAttack;
        private static int totalNormal;
        private static int totalAttack;

        /// <summary>
        /// Entry point: initialises the model, trains on the training file, evaluates on
        /// the test file, then prints the total wall-clock time.
        /// </summary>
        static void Main(string[] args)
        {
            Stopwatch totalTimer = Stopwatch.StartNew();

            vw.Init();
            DoTraining();
            DoEvaluate();

            totalTimer.Stop();
            Console.WriteLine("Done. ElapsedTime: " + totalTimer.Elapsed);
        }


        /// <summary>
        /// Scores every test record whose label is known. A prediction at or above the
        /// cutoff counts as "attack", below it as "normal". Accumulates per-class totals
        /// and hits, then prints them.
        /// </summary>
        static void DoEvaluate()
        {
            float threshold = 0.4f;
            DataSource evalSource = new DataSource(testSource);

            while (evalSource.NextRecord())
            {
                DataRecord current = evalSource.Record;

                // Records with labels DataSource does not recognise are not scored.
                if (!current.isKnownAttackType)
                {
                    continue;
                }

                float score = vw.Predict(current.GetVWRecord());

                if (current.labelInt == 0)
                {
                    totalNormal++;
                    if (score < threshold)
                    {
                        correctNormal++;
                    }
                }
                else if (current.labelInt == 1)
                {
                    totalAttack++;
                    if (score >= threshold)
                    {
                        correctAttack++;
                    }
                }

                evaluateRecordCount++;
            }

            Console.WriteLine("Evaluate Complete. evaluateRecordCount = " + evaluateRecordCount);
            Console.WriteLine("Evaluate totalNormal = " + totalNormal + " correctNormal = " + correctNormal);
            Console.WriteLine("Evaluate totalAttack = " + totalAttack + " correctAttack = " + correctAttack);
            Console.WriteLine("Evaluate DONE!");
        }

        /// <summary>
        /// Feeds every parseable record of the training file to the model, counting the
        /// records trained on, and prints the count.
        /// </summary>
        static void DoTraining()
        {
            DataSource source = new DataSource(trainingSource);
            while (source.NextRecord())
            {
                vw.Train(source.Record);
                // Previously never incremented, so the summary below always reported 0.
                trainingRecordCount++;
            }
            Console.WriteLine("Train Complete. trainingRecordCount = " + trainingRecordCount);
        }

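This solution was incredibly quick and easy to implement, and it correctly classifies 99.6% of normal records and about 97% of attack records. Note that DoEvaluate only scores test records whose label appears in DataSource’s list of known labels. Test records with any other label (attack types absent from that list) are skipped, so they are not included in these figures.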

Finally, here is the data source example so you can cut & paste to try this out yourself:

    public class DataSource
    {
        // Expected comma-separated columns per line: 41 feature columns (fields[0..40])
        // plus the label (fields[41]). Training and test files use the same layout.
        private const int COLUMN_COUNT = 42;
        private const int COLUMN_COUNT_TEST = 42;

        // Record produced by the most recent NextRecord() call.
        public DataRecord Record = new DataRecord();

        // Path of the file being read; fixed at construction and reopened by Reset().
        private readonly string sourceReport;

        // Reader over sourceReport, (re)created by Reset().
        private System.IO.StreamReader fileReader;
        // Lines handled by NextRecord() since the last Reset(); reported in error messages.
        private int sourceIndex;

        /// <summary>
        /// Opens <paramref name="sourceReport"/> for reading via <see cref="Reset"/>.
        /// </summary>
        /// <param name="sourceReport">Path of the comma-separated data file.</param>
        public DataSource(string sourceReport)
        {
            // Remember the path first: Reset() opens whatever sourceReport holds.
            this.sourceReport = sourceReport;
            this.Reset();
        }

        // Labels recognised by NextRecord: "normal" plus the attack types listed here.
        // Records carrying any other label get isKnownAttackType = false.
        private static readonly System.Collections.Generic.HashSet<string> KnownLabels =
            new System.Collections.Generic.HashSet<string>(System.StringComparer.Ordinal)
            {
                "back", "buffer_overflow", "ftp_write", "guess_passwd", "imap",
                "ipsweep", "land", "loadmodule", "multihop", "neptune", "nmap",
                "normal", "perl", "phf", "pod", "portsweep", "rootkit", "satan",
                "smurf", "spy", "teardrop", "warezclient", "warezmaster"
            };

        // Parses one numeric column with the invariant culture, so values such as "0.25"
        // are read identically regardless of the machine's regional settings.
        private static float ParseFeature(string[] fields, int index)
        {
            return float.Parse(fields[index], System.Globalization.CultureInfo.InvariantCulture);
        }

        /// <summary>
        /// Advances to the next well-formed line of the source file and exposes it via
        /// <see cref="Record"/>. Lines with the wrong column count or unparseable values
        /// are logged to the console and skipped; <see cref="Record"/> is only replaced
        /// by a fully parsed line.
        /// </summary>
        /// <returns>true if a record was read; false once the end of the file is reached.</returns>
        public bool NextRecord()
        {
            while (!fileReader.EndOfStream)
            {
                try
                {
                    string line = fileReader.ReadLine();
                    // Trailing '.' characters (at the end of the label column) are trimmed before splitting.
                    string[] fields = line.TrimEnd('.').Split(',');

                    if (fields.Length != COLUMN_COUNT && fields.Length != COLUMN_COUNT_TEST)
                    {
                        throw new FormatException(string.Format("sourceReportParser column count [{0}] != expected COLUMN_COUNT [{1}]", fields.Length, COLUMN_COUNT));
                    }

                    // Fill a local first so a parse failure cannot leave Record half-populated.
                    DataRecord record = new DataRecord();
                    record.duration = ParseFeature(fields, 0);
                    record.protocol_type = fields[1];
                    record.service = fields[2];
                    record.flag = fields[3];
                    record.src_bytes = ParseFeature(fields, 4);
                    record.dst_bytes = ParseFeature(fields, 5);
                    record.land = ParseFeature(fields, 6);
                    record.wrong_fragment = ParseFeature(fields, 7);
                    record.urgent = ParseFeature(fields, 8);
                    record.hot = ParseFeature(fields, 9);
                    record.num_failed_logins = ParseFeature(fields, 10);
                    record.logged_in = ParseFeature(fields, 11);
                    record.num_compromised = ParseFeature(fields, 12);
                    record.root_shell = ParseFeature(fields, 13);
                    record.su_attempted = ParseFeature(fields, 14);
                    record.num_root = ParseFeature(fields, 15);
                    record.num_file_creations = ParseFeature(fields, 16);
                    record.num_shells = ParseFeature(fields, 17);
                    record.num_access_files = ParseFeature(fields, 18);
                    record.num_outbound_cmds = ParseFeature(fields, 19);
                    record.is_host_login = ParseFeature(fields, 20);
                    record.is_guest_login = ParseFeature(fields, 21);
                    record.count = ParseFeature(fields, 22);
                    record.srv_count = ParseFeature(fields, 23);
                    record.serror_rate = ParseFeature(fields, 24);
                    record.srv_serror_rate = ParseFeature(fields, 25);
                    record.rerror_rate = ParseFeature(fields, 26);
                    record.srv_rerror_rate = ParseFeature(fields, 27);
                    record.same_srv_rate = ParseFeature(fields, 28);
                    record.diff_srv_rate = ParseFeature(fields, 29);
                    record.srv_diff_host_rate = ParseFeature(fields, 30);
                    record.dst_host_count = ParseFeature(fields, 31);
                    record.dst_host_srv_count = ParseFeature(fields, 32);
                    record.dst_host_same_srv_rate = ParseFeature(fields, 33);
                    record.dst_host_diff_srv_rate = ParseFeature(fields, 34);
                    record.dst_host_same_src_port_rate = ParseFeature(fields, 35);
                    record.dst_host_srv_diff_host_rate = ParseFeature(fields, 36);
                    record.dst_host_serror_rate = ParseFeature(fields, 37);
                    record.dst_host_srv_serror_rate = ParseFeature(fields, 38);
                    record.dst_host_rerror_rate = ParseFeature(fields, 39);
                    record.dst_host_srv_rerror_rate = ParseFeature(fields, 40);

                    record.label = fields[41];
                    record.isKnownAttackType = KnownLabels.Contains(record.label);
                    // Binary target for training and scoring: 0 = normal, 1 = any attack.
                    record.labelInt = record.label == "normal" ? 0 : 1;

                    Record = record;
                    sourceIndex++;
                    return true;
                }
                catch (Exception e)
                {
                    Console.WriteLine("ERROR: NextRecord failed for line: " + sourceIndex + " with exception: " + e.Message + " Stack: " + e.StackTrace);
                    sourceIndex++;
                }
            }

            return false;
        }

        /// <summary>
        /// (Re)opens <c>sourceReport</c> from the beginning and resets the line counter.
        /// A leading header line is consumed, and its column count validated, only when
        /// the file does not start with a digit. The first column (duration) is numeric,
        /// and the KDD99 files are distributed without a header row. Unconditionally
        /// burning the first line therefore silently dropped one data record. An empty
        /// file is accepted.
        /// </summary>
        /// <exception cref="System.IO.InvalidDataException">
        /// A header line is present but has an unexpected number of columns.
        /// </exception>
        public void Reset()
        {
            // Release the handle from any previous pass before reopening the file.
            if (fileReader != null)
            {
                fileReader.Dispose();
            }

            fileReader = new System.IO.StreamReader(sourceReport);
            sourceIndex = 0;

            int first = fileReader.Peek();
            if (first == -1 || char.IsDigit((char)first))
            {
                // Empty file, or the first line is already a data record: nothing to skip.
                return;
            }

            // Burn the column-header line, checking that it has the expected shape.
            string header = fileReader.ReadLine();
            int headerColumns = header.Split(',').Length;
            if (headerColumns != COLUMN_COUNT && headerColumns != COLUMN_COUNT_TEST)
            {
                // Don't keep the file handle open for a source we are rejecting.
                fileReader.Dispose();
                throw new System.IO.InvalidDataException(string.Format("sourceReportParser column count [{0}] != expected COLUMN_COUNT [{1}]", headerColumns, COLUMN_COUNT));
            }
        }