Haplo Prediction
predict haplogroups
haplo_train.c
Go to the documentation of this file.
00001 /*
00002  * This work is licensed under a Creative Commons 
00003  * Attribution-Noncommercial-Share Alike 3.0 United States License.
00004  * 
00005  *    http://creativecommons.org/licenses/by-nc-sa/3.0/us/
00006  * 
00007  * You are free:
00008  * 
00009  *    to Share - to copy, distribute, display, and perform the work
00010  *    to Remix - to make derivative works
00011  * 
00012  * Under the following conditions:
00013  * 
00014  *    Attribution. You must attribute the work in the manner specified by the
00015  *    author or licensor (but not in any way that suggests that they endorse you
00016  *    or your use of the work).
00017  * 
00018  *    Noncommercial. You may not use this work for commercial purposes.
00019  * 
00020  *    Share Alike. If you alter, transform, or build upon this work, you may
00021  *    distribute the resulting work only under the same or similar license to
00022  *    this one.
00023  * 
00024  * For any reuse or distribution, you must make clear to others the license
00025  * terms of this work. The best way to do this is by including this header.
00026  * 
00027  * Any of the above conditions can be waived if you get permission from the
00028  * copyright holder.
00029  * 
00030  * Apart from the remix rights granted under this license, nothing in this
00031  * license impairs or restricts the author's moral rights.
00032  */
00033 
00034 
00222 #include <config.h>
00223 
00224 #include <stdlib.h>
00225 #include <stdio.h>
00226 #include <assert.h>
00227 #include <string.h>
00228 #include <errno.h>
00229 
00230 #ifdef HAPLO_HAVE_DMALLOC
00231 #include <dmalloc.h>
00232 #endif
00233 
00234 #include <jwsc/base/error.h>
00235 #include <jwsc/base/option.h>
00236 #include <jwsc/vector/vector.h>
00237 #include <jwsc/matrix/matrix.h>
00238 #include <jwsc/matblock/matblock.h>
00239 
00240 #include "haplo_groups.h"
00241 #include "options.h"
00242 #include "input.h"
00243 #include "nb_freq.h"
00244 #include "nb_gauss.h"
00245 #include "nb_gmm.h"
00246 #include "mv_gmm.h"
00247 #include "mv_mmm.h"
00248 #ifdef HAPLO_ENABLE_SVM
00249 #include "svm_tree.h"
00250 #endif
00251 #ifdef HAPLO_ENABLE_WEKA
00252 #include "weka.h"
00253 #endif
00254 #include "nearest.h"
00255 
00256 
/* Number of SVM options registered by init_train_options() when SVM support
 * is compiled in ("svm" and "svm-dtd"); zero otherwise. */
#ifdef HAPLO_ENABLE_SVM
#define NUM_SVM_OPTS 2
#else
#define NUM_SVM_OPTS 0
#endif

/* Number of Weka options registered by init_train_options() when Weka support
 * is compiled in ("weka-j48", "weka-part", "weka-jar" and "weka-dtd"); zero
 * otherwise. */
#ifdef HAPLO_ENABLE_WEKA
#define NUM_WEKA_OPTS 4
#else
#define NUM_WEKA_OPTS 0
#endif

/* Total sizes of the option tables. The NUM_SHARED_OPTS_* counts are defined
 * outside this file. The literal 14 is the number of options that
 * init_train_options() always registers; keep the two in sync.
 *
 * The replacement lists are parenthesized so the macros behave as a single
 * value inside any larger expression (e.g. 2 * NUM_OPTS_WITH_ARG). */
#define  NUM_OPTS_NO_ARG    (0 + NUM_SHARED_OPTS_NO_ARG)
#define  NUM_OPTS_WITH_ARG  (14 + NUM_SVM_OPTS + NUM_WEKA_OPTS + NUM_SHARED_OPTS_WITH_ARG)


/* Value stored in opts.label_col by init_train_options(). */
#define  LABEL_COL         1
/* NOTE(review): not referenced anywhere in the visible part of this file. */
#define  DATA_OUT_DIRNAME  0
00275 
00276 
00278 Option_no_arg opts_no_arg[NUM_OPTS_NO_ARG];
00279 
00281 Option_with_arg opts_with_arg[NUM_OPTS_WITH_ARG];
00282 
00284 const char* data_out_dirname;
00285 
00286 
00288 uint32_t get_num_opts_no_arg()
00289 {
00290     return NUM_OPTS_NO_ARG;
00291 }
00292 
00294 uint32_t get_num_opts_with_arg()
00295 {
00296     return NUM_OPTS_WITH_ARG;
00297 }
00298 
00300 void print_usage()
00301 {
00302     fprintf(stderr, "usage: haplo-train OPTIONS [data-fname | <stdin>]\n");
00303     print_options(stderr, 27, NUM_OPTS_NO_ARG, opts_no_arg, NUM_OPTS_WITH_ARG,
00304             opts_with_arg);
00305 }
00306 
00308 Error* process_data_out_dir_opt(Option_arg arg)
00309 {
00310     if (arg == NULL)
00311     {
00312         return JWSC_EARG("Option 'data-out-dir' requires an argument");
00313     }
00314     data_out_dirname = arg;
00315 
00316     return NULL;
00317 }
00318 
00320 static void init_train_options(void)
00321 {
00322     uint32_t i;
00323 
00324     char s_name;
00325     const char* l_name;
00326     const char* desc;
00327 
00328     Error* (*farg)(const char*);
00329 
00330     init_options(opts_no_arg, opts_with_arg);
00331 
00332     opts.label_col = LABEL_COL;
00333 
00334     i = NUM_SHARED_OPTS_WITH_ARG;
00335     l_name = "model-dir";
00336     s_name = 0;
00337     desc   = "Directory to put trained models in.";
00338     farg   = process_model_dir_opt;
00339     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00340 
00341     l_name = "data-out-dir";
00342     s_name = 0;
00343     desc   = "Directory to put the generated training data in for each model. The name of the model is used, so if this directory is set as the same as the model-dir, the models could be overwritten. The default is not to output the data.";
00344     farg   = process_data_out_dir_opt;
00345     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00346 
00347     l_name = "nb-freq";
00348     s_name = 0;
00349     desc   = "Naive Bayes non-parametric frequency model tree information.";
00350     farg   = process_nb_freq_opt;
00351     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00352 
00353     l_name = "nb-freq-dtd";
00354     s_name = 0;
00355     desc   = "Validate the naive Bayes non-parametric frequency model tree information XML file with this DTD.";
00356     farg   = process_nb_freq_dtd_opt;
00357     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00358 
00359     l_name = "nb-gauss";
00360     s_name = 0;
00361     desc   = "Naive Bayes Gaussian model tree information.";
00362     farg   = process_nb_gauss_opt;
00363     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00364 
00365     l_name = "nb-gauss-dtd";
00366     s_name = 0;
00367     desc   = "Validate the naive Bayes Gaussian model tree information XML file with this DTD.";
00368     farg   = process_nb_gauss_dtd_opt;
00369     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00370 
00371     l_name = "nb-gmm";
00372     s_name = 0;
00373     desc   = "Naive Bayes Gaussian mixture model tree information.";
00374     farg   = process_nb_gmm_opt;
00375     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00376 
00377     l_name = "nb-gmm-dtd";
00378     s_name = 0;
00379     desc   = "Validate the naive Bayes Gaussian mixture model tree information XML file with this DTD.";
00380     farg   = process_nb_gmm_dtd_opt;
00381     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00382 
00383     l_name = "mv-gmm";
00384     s_name = 0;
00385     desc   = "Multivariate Gaussian mixture model tree information.";
00386     farg   = process_mv_gmm_opt;
00387     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00388 
00389     l_name = "mv-gmm-dtd";
00390     s_name = 0;
00391     desc   = "Validate the multivariate Gaussian mixture model tree information XML file with this DTD.";
00392     farg   = process_mv_gmm_dtd_opt;
00393     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00394 
00395     l_name = "mv-mmm";
00396     s_name = 0;
00397     desc   = "Multivariate multinomial mixture model tree information.";
00398     farg   = process_mv_mmm_opt;
00399     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00400 
00401     l_name = "mv-mmm-dtd";
00402     s_name = 0;
00403     desc   = "Validate the multivariate multinomial mixture model tree information XML file with this DTD.";
00404     farg   = process_mv_mmm_dtd_opt;
00405     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00406 
00407 #ifdef HAPLO_ENABLE_SVM
00408     l_name = "svm";
00409     s_name = 0;
00410     desc   = "SVM model tree information.";
00411     farg   = process_svm_opt;
00412     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00413 
00414     l_name = "svm-dtd";
00415     s_name = 0;
00416     desc   = "Validate the SVM model tree information XML file with this DTD.";
00417     farg   = process_svm_dtd_opt;
00418     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00419 #endif
00420 
00421 #ifdef HAPLO_ENABLE_WEKA
00422     l_name = "weka-j48";
00423     s_name = 0;
00424     desc   = "Weka J48 model tree information.";
00425     farg   = process_weka_j48_opt;
00426     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00427 
00428     l_name = "weka-part";
00429     s_name = 0;
00430     desc   = "Weka PART model tree information.";
00431     farg   = process_weka_part_opt;
00432     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00433 
00434     l_name = "weka-jar";
00435     s_name = 0;
00436     desc   = "Weka java archive file. Required for using the Weka algorithms.";
00437     farg   = process_weka_jar_opt;
00438     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00439 
00440     l_name = "weka-dtd";
00441     s_name = 0;
00442     desc   = "Validate the Weka model tree information XML files with this DTD.";
00443     farg   = process_weka_dtd_opt;
00444     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00445 #endif
00446 
00447     l_name = "nearest";
00448     s_name = 0;
00449     desc   = "Nearest neighbor model information.";
00450     farg   = process_nearest_opt;
00451     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00452 
00453     l_name = "nearest-dtd";
00454     s_name = 0;
00455     desc   = "Validate the nearest neighbor model information XML file with this DTD.";
00456     farg   = process_nearest_dtd_opt;
00457     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00458     assert(i == NUM_OPTS_WITH_ARG);
00459 }
00460 
00461 
00463 int main(int argc, const char** argv)
00464 {
00465     int         argi;
00466     const char* data_fname = "/dev/stdin";
00467     Error*      err;
00468 
00469     Matblock_u8* ids    = NULL;
00470     Vector_u32* labels  = NULL;
00471     Matrix_i32* markers = NULL;
00472 
00473     NB_freq_model_tree*  nb_freq_tree  = NULL;
00474     NB_gauss_model_tree* nb_gauss_tree = NULL;
00475     NB_gmm_model_tree*   nb_gmm_tree   = NULL;
00476     MV_gmm_model_tree*   mv_gmm_tree   = NULL;
00477     MV_mmm_model_tree*   mv_mmm_tree   = NULL;
00478 #ifdef HAPLO_ENABLE_SVM
00479     SVM_model_tree*      svm_tree      = NULL;
00480 #endif
00481 #ifdef HAPLO_ENABLE_WEKA
00482     Weka_model_tree*     j48_tree      = NULL;
00483     Weka_model_tree*     part_tree     = NULL;
00484 #endif
00485     Nearest_model*       nearest_model = NULL;
00486 
00487     init_train_options();
00488 
00489     if ((err = process_options(argc, argv, &argi, NUM_OPTS_NO_ARG, opts_no_arg,
00490                     NUM_OPTS_WITH_ARG, opts_with_arg)) != NULL)
00491     {
00492         print_error_msg_exit("haplo-train", err->msg);
00493     }
00494 
00495     if ((argc - argi) == 1)
00496     {
00497         data_fname = argv[ argi ];
00498     }
00499 
00500     if ((err = read_haplo_groups(opts.labels_fname)))
00501     {
00502         print_error_msg_exit("haplo-train", err->msg);
00503     }
00504 
00505     if ((err = read_input(&ids, &labels, &markers, data_fname)))
00506     {
00507         print_error_msg_exit("haplo-train", err->msg);
00508     }
00509 
00510     if (!labels)
00511     {
00512         print_error_msg("haplo-train", NULL);
00513         print_error_msg_exit(data_fname, "No labels to train with");
00514     }
00515 
00516     if (opts.nb_freq_fname)
00517     {
00518         if ((err = train_nb_freq_model_tree(&nb_freq_tree, labels, markers,
00519                         opts.nb_freq_fname, opts.nb_freq_dtd_fname)) ||
00520             (err = write_nb_freq_model_tree(nb_freq_tree, opts.model_dirname)))
00521         {
00522             print_error_msg_exit("haplo-train", err->msg);
00523         }
00524         free_nb_freq_model_tree(nb_freq_tree);
00525     }
00526 
00527     if (opts.nb_gauss_fname)
00528     {
00529         if ((err = train_nb_gauss_model_tree(&nb_gauss_tree, labels, markers, 
00530                         opts.nb_gauss_fname, opts.nb_gauss_dtd_fname)) ||
00531             (err = write_nb_gauss_model_tree(nb_gauss_tree, 
00532                         opts.model_dirname)))
00533         {
00534             print_error_msg_exit("heplo-train", err->msg);
00535         }
00536         free_nb_gauss_model_tree(nb_gauss_tree);
00537     }
00538 
00539     if (opts.nb_gmm_fname)
00540     {
00541         if ((err = train_nb_gmm_model_tree(&nb_gmm_tree, labels, markers, 
00542                         opts.nb_gmm_fname, opts.nb_gmm_dtd_fname)) ||
00543             (err = write_nb_gmm_model_tree(nb_gmm_tree, opts.model_dirname)))
00544         {
00545             print_error_msg_exit("haplo-train", err->msg);
00546         }
00547         free_nb_gmm_model_tree(nb_gmm_tree);
00548     }
00549 
00550     if (opts.mv_gmm_fname)
00551     {
00552         if ((err = train_mv_gmm_model_tree(&mv_gmm_tree, labels, markers, 
00553                         opts.mv_gmm_fname, opts.mv_gmm_dtd_fname)) ||
00554             (err = write_mv_gmm_model_tree(mv_gmm_tree, opts.model_dirname)))
00555         {
00556             print_error_msg_exit("haplo-train", err->msg);
00557         }
00558         free_mv_gmm_model_tree(mv_gmm_tree);
00559     }
00560 
00561     if (opts.mv_mmm_fname)
00562     {
00563         if ((err = train_mv_mmm_model_tree(&mv_mmm_tree, labels, markers, 
00564                         opts.mv_mmm_fname, opts.mv_mmm_dtd_fname)) ||
00565             (err = write_mv_mmm_model_tree(mv_mmm_tree, opts.model_dirname)))
00566         {
00567             print_error_msg_exit("haplo-train", err->msg);
00568         }
00569         free_mv_mmm_model_tree(mv_mmm_tree);
00570     }
00571 
00572 #ifdef HAPLO_ENABLE_SVM
00573     if (opts.svm_fname)
00574     {
00575         if ((err = train_svm_model_tree(&svm_tree, labels, markers,
00576                         opts.svm_fname, opts.svm_dtd_fname)) ||
00577             (err = write_svm_model_tree(svm_tree, opts.model_dirname)))
00578         {
00579             print_error_msg_exit("heplo-train", err->msg);
00580         }
00581         free_svm_model_tree(svm_tree);
00582         if (data_out_dirname)
00583         {
00584             if ((err = write_svm_model_tree_training_data(labels, markers,
00585                             opts.svm_fname, opts.svm_dtd_fname,
00586                             data_out_dirname)))
00587             {
00588                 print_error_msg_exit("haplo-train", err->msg);
00589             }
00590         }
00591     }
00592 #endif
00593 
00594 #ifdef HAPLO_ENABLE_WEKA
00595     if (opts.weka_j48_fname)
00596     {
00597         if ((err = train_weka_j48_model_tree(&j48_tree, labels, markers,
00598                         opts.weka_j48_fname, opts.weka_dtd_fname,
00599                         opts.model_dirname, opts.weka_jar_fname)))
00600         {
00601             print_error_msg_exit("heplo-train", err->msg);
00602         }
00603         free_weka_model_tree(j48_tree);
00604     }
00605 
00606     if (opts.weka_part_fname)
00607     {
00608         if ((err = train_weka_part_model_tree(&part_tree, labels, markers,
00609                         opts.weka_part_fname, opts.weka_dtd_fname,
00610                         opts.model_dirname, opts.weka_jar_fname)))
00611         {
00612             print_error_msg_exit("heplo-train", err->msg);
00613         }
00614         free_weka_model_tree(part_tree);
00615     }
00616 #endif
00617 
00618     if (opts.nearest_fname)
00619     {
00620         if ((err = train_nearest_model(&nearest_model, labels, markers,
00621                         opts.nearest_fname, opts.nearest_dtd_fname)) ||
00622             (err = write_nearest_model(nearest_model, opts.model_dirname)))
00623         {
00624             print_error_msg_exit("haplo-train", err->msg);
00625         }
00626         free_nearest_model(nearest_model);
00627     }
00628 
00629     free_matblock_u8(ids);
00630     free_vector_u32(labels);
00631     free_matrix_i32(markers);
00632 
00633     if (get_num_unhandled_errors() > 0)
00634     {
00635         print_error_msg_exit("haplo-train", "Unhandled errors");
00636     }
00637 
00638     return EXIT_SUCCESS;
00639 }