Haplo Prediction
predict haplogroups
haplo_predict.c
Go to the documentation of this file.
00001 /*
00002  * This work is licensed under a Creative Commons 
00003  * Attribution-Noncommercial-Share Alike 3.0 United States License.
00004  * 
00005  *    http://creativecommons.org/licenses/by-nc-sa/3.0/us/
00006  * 
00007  * You are free:
00008  * 
00009  *    to Share - to copy, distribute, display, and perform the work
00010  *    to Remix - to make derivative works
00011  * 
00012  * Under the following conditions:
00013  * 
00014  *    Attribution. You must attribute the work in the manner specified by the
00015  *    author or licensor (but not in any way that suggests that they endorse you
00016  *    or your use of the work).
00017  * 
00018  *    Noncommercial. You may not use this work for commercial purposes.
00019  * 
00020  *    Share Alike. If you alter, transform, or build upon this work, you may
00021  *    distribute the resulting work only under the same or similar license to
00022  *    this one.
00023  * 
00024  * For any reuse or distribution, you must make clear to others the license
00025  * terms of this work. The best way to do this is by including this header.
00026  * 
00027  * Any of the above conditions can be waived if you get permission from the
00028  * copyright holder.
00029  * 
00030  * Apart from the remix rights granted under this license, nothing in this
00031  * license impairs or restricts the author's moral rights.
00032  */
00033 
00034 
00055 #include <config.h>
00056 
00057 #include <stdlib.h>
00058 #include <stdio.h>
00059 #include <string.h>
00060 #include <assert.h>
00061 #include <inttypes.h>
00062 
00063 #include <libxml/tree.h>
00064 
00065 #ifdef HAPLO_HAVE_PTHREAD
00066 #include <pthread.h>
00067 #endif
00068 
00069 #ifdef HAPLO_HAVE_DMALLOC
00070 #include <dmalloc.h>
00071 #endif
00072 
00073 #include <jwsc/base/error.h>
00074 #include <jwsc/base/option.h>
00075 #include <jwsc/base/file_io.h>
00076 #include <jwsc/vector/vector.h>
00077 #include <jwsc/matrix/matrix.h>
00078 #include <jwsc/matblock/matblock.h>
00079 
00080 #include "haplo_groups.h"
00081 #include "options.h"
00082 #include "output.h"
00083 #include "input.h"
00084 #include "xml.h"
00085 #include "nb_freq.h"
00086 #include "nb_gauss.h"
00087 #include "nb_gmm.h"
00088 #include "mv_gmm.h"
00089 #ifdef HAPLO_ENABLE_SVM
00090 #include "svm_tree.h"
00091 #endif
00092 #ifdef HAPLO_ENABLE_WEKA
00093 #include "weka.h"
00094 #endif
00095 #include "nearest.h"
00096 
00097 
/* Options contributed by optional back ends: svm, svm-dtd. */
#ifdef HAPLO_ENABLE_SVM
#define NUM_SVM_OPTS 2
#else
#define NUM_SVM_OPTS 0
#endif

/* Options contributed by optional back ends: weka-j48, weka-part, weka-jar,
 * weka-dtd. */
#ifdef HAPLO_ENABLE_WEKA
#define NUM_WEKA_OPTS 4
#else
#define NUM_WEKA_OPTS 0
#endif

/*
 * Total option counts: this program's own options plus the shared ones
 * (NUM_SHARED_OPTS_*, defined outside this file). The expansions are fully
 * parenthesized so the macros stay correct inside larger expressions such as
 * 2 * NUM_OPTS_NO_ARG.
 */
#define  NUM_OPTS_NO_ARG    (1 + NUM_SHARED_OPTS_NO_ARG)
#define  NUM_OPTS_WITH_ARG  (13 + NUM_SVM_OPTS + NUM_WEKA_OPTS + NUM_SHARED_OPTS_WITH_ARG)


/* Number of classifier slots dispatched by predict(). */
#define  NUM_ALGOS     8
/* Column of the input holding the sample label. */
#define  LABEL_COL     0
/* Default destination for the predictions. */
#define  OUTPUT_FNAME  "/dev/stdout"
00117 
00118 
00124 typedef struct 
00125 { 
00126     Vector_u32** labels;
00127     Vector_d**   confs;
00128     Matrix_i32*  markers; 
00129     uint32_t     n;
00130 #ifdef HAPLO_HAVE_PTHREAD
00131     pthread_mutex_t mutex;
00132 #endif
00133 }
00134 Predict_params;
00135 
00136 
00138 Option_no_arg opts_no_arg[NUM_OPTS_NO_ARG];
00139 
00141 Option_with_arg opts_with_arg[NUM_OPTS_WITH_ARG];
00142 
00144 static const char* output_fname = OUTPUT_FNAME;
00145 
00146 
00148 uint32_t get_num_opts_no_arg()
00149 {
00150     return NUM_OPTS_NO_ARG;
00151 }
00152 
00154 uint32_t get_num_opts_with_arg()
00155 {
00156     return NUM_OPTS_WITH_ARG;
00157 }
00158 
00160 void print_usage()
00161 {
00162     fprintf(stderr, "usage: haplo-predict OPTIONS [data-fname | <stdin>]\n");
00163     print_options(stderr, 27, NUM_OPTS_NO_ARG, opts_no_arg, NUM_OPTS_WITH_ARG,
00164             opts_with_arg);
00165 }
00166 
00168 Error* process_output_opt(Option_arg arg)
00169 {
00170     if (arg == NULL)
00171     {
00172         return JWSC_EARG("Option 'output' requires an argument");
00173     }
00174     output_fname = arg;
00175     return NULL;
00176 }
00177 
00179 static void init_predict_options(void)
00180 {
00181     uint32_t i;
00182 
00183     char s_name;
00184     const char* l_name;
00185     const char* desc;
00186 
00187     Error* (*fnoarg)();
00188     Error* (*farg)(const char*);
00189 
00190     init_options(opts_no_arg, opts_with_arg);
00191 
00192     opts.label_col = LABEL_COL;
00193 
00194     i = NUM_SHARED_OPTS_NO_ARG;
00195     l_name = "exclude-one";
00196     s_name = 0;
00197     desc   = "When performing the tandem prediction decision, exclude at most one prediction from the set of classification algorithms. There must be three or more algorithms in play for this to take effect.";
00198     fnoarg = process_exclude_one_opt;
00199     init_option_no_arg(&(opts_no_arg[i++]), l_name, s_name, desc, fnoarg);
00200     assert(i == NUM_OPTS_NO_ARG);
00201 
00202     i = NUM_SHARED_OPTS_WITH_ARG;
00203     l_name = "output";
00204     s_name = 0;
00205     desc   = "File to output the predictions to. The default is stdout.";
00206     farg   = process_output_opt;
00207     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00208 
00209     l_name = "model-dir";
00210     s_name = 0;
00211     desc   = "Directory containing the trained models.";
00212     farg   = process_model_dir_opt;
00213     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00214 
00215     l_name = "nb-freq";
00216     s_name = 0;
00217     desc   = "Naive Bayes non-parametric frequency model tree information.";
00218     farg   = process_nb_freq_opt;
00219     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00220 
00221     l_name = "nb-freq-dtd";
00222     s_name = 0;
00223     desc   = "Validate the naive Bayes non-parametric frequency model tree information XML file with this DTD.";
00224     farg   = process_nb_freq_dtd_opt;
00225     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00226 
00227     l_name = "nb-gauss";
00228     s_name = 0;
00229     desc   = "Naive Bayes Gaussian model tree information.";
00230     farg   = process_nb_gauss_opt;
00231     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00232 
00233     l_name = "nb-gauss-dtd";
00234     s_name = 0;
00235     desc   = "Validate the naive Bayes Gaussian model tree information XML file with this DTD.";
00236     farg   = process_nb_gauss_dtd_opt;
00237     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00238 
00239     l_name = "nb-gmm";
00240     s_name = 0;
00241     desc   = "Naive Bayes Gaussian mixture model tree information.";
00242     farg   = process_nb_gmm_opt;
00243     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00244 
00245     l_name = "nb-gmm-dtd";
00246     s_name = 0;
00247     desc   = "Validate the naive Bayes Gaussian mixture model tree information XML file with this DTD.";
00248     farg   = process_nb_gmm_dtd_opt;
00249     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00250 
00251     l_name = "mv-gmm";
00252     s_name = 0;
00253     desc   = "Multivariate Gaussian mixture model tree information.";
00254     farg   = process_mv_gmm_opt;
00255     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00256 
00257     l_name = "mv-gmm-dtd";
00258     s_name = 0;
00259     desc   = "Validate the multivariate Gaussian mixture model tree information XML file with this DTD.";
00260     farg   = process_mv_gmm_dtd_opt;
00261     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00262 
00263 #ifdef HAPLO_ENABLE_SVM
00264     l_name = "svm";
00265     s_name = 0;
00266     desc   = "SVM model tree information.";
00267     farg   = process_svm_opt;
00268     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00269 
00270     l_name = "svm-dtd";
00271     s_name = 0;
00272     desc   = "Validate the SVM model tree information XML file with this DTD.";
00273     farg   = process_svm_dtd_opt;
00274     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00275 #endif
00276 
00277 #ifdef HAPLO_ENABLE_WEKA
00278     l_name = "weka-j48";
00279     s_name = 0;
00280     desc   = "Weka J48 model tree information.";
00281     farg   = process_weka_j48_opt;
00282     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00283 
00284     l_name = "weka-part";
00285     s_name = 0;
00286     desc   = "Weka PART model tree information.";
00287     farg   = process_weka_part_opt;
00288     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00289 
00290     l_name = "weka-jar";
00291     s_name = 0;
00292     desc   = "Weka java archive file. Required for using the Weka algorithms.";
00293     farg   = process_weka_jar_opt;
00294     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00295 
00296     l_name = "weka-dtd";
00297     s_name = 0;
00298     desc   = "Validate the Weka model tree information XML files with this DTD.";
00299     farg   = process_weka_dtd_opt;
00300     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00301 #endif
00302 
00303     l_name = "nearest";
00304     s_name = 0;
00305     desc   = "Nearest neighbor model information.";
00306     farg   = process_nearest_opt;
00307     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00308 
00309     l_name = "nearest-dtd";
00310     s_name = 0;
00311     desc   = "Validate the nearest neighbor model information XML file with this DTD.";
00312     farg   = process_nearest_dtd_opt;
00313     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00314 
00315     l_name = "nearest-max-d";
00316     s_name = 0;
00317     desc   = "Maximum distance allowed for a nearest neighbor classification.";
00318     farg   = process_nearest_max_d_opt;
00319     init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg);
00320     assert(i == NUM_OPTS_WITH_ARG);
00321 }
00322 
00324 static uint8_t num_models_to_predict()
00325 {
00326     return (opts.nb_freq_fname       != 0) +
00327            (opts.nb_gauss_fname      != 0) +
00328            (opts.nb_gmm_fname        != 0) +
00329            (opts.mv_gmm_fname        != 0) +
00330 #ifdef HAPLO_ENABLE_SVM
00331            (opts.svm_fname           != 0) +
00332 #endif
00333 #ifdef HAPLO_ENABLE_WEKA
00334            (opts.weka_j48_fname      != 0) +
00335            (opts.weka_part_fname     != 0) +
00336 #endif
00337            (opts.nearest_fname       != 0);
00338 }
00339 
00341 static void predict_nb_freq
00342 (
00343     Vector_u32**      labels_out,
00344     Vector_d**        confs_out,
00345     const Matrix_i32* markers
00346 )
00347 {
00348     NB_freq_model_tree* tree = NULL;
00349     Error* err;
00350 
00351     if (!opts.nb_freq_fname)
00352         return;
00353 
00354     if ((err = read_nb_freq_model_tree(&tree, opts.nb_freq_fname,
00355                     opts.nb_freq_dtd_fname, opts.model_dirname)) ||
00356         (err = predict_labels_with_nb_freq_model_tree(labels_out, 
00357                     confs_out, markers, tree, 0)))
00358     {
00359         print_error_msg_exit("haplo-predict", err->msg);
00360     }
00361 
00362     free_nb_freq_model_tree(tree);
00363 }
00364 
00366 static void predict_nb_gauss
00367 (
00368     Vector_u32**      labels_out,
00369     Vector_d**        confs_out,
00370     const Matrix_i32* markers
00371 )
00372 {
00373     NB_gauss_model_tree* tree = NULL;
00374     Error* err;
00375 
00376     if (!opts.nb_gauss_fname)
00377         return;
00378 
00379     if ((err = read_nb_gauss_model_tree(&tree, opts.nb_gauss_fname, 
00380                     opts.nb_gauss_dtd_fname, opts.model_dirname)) ||
00381         (err = predict_labels_with_nb_gauss_model_tree(labels_out, 
00382                     confs_out, markers, tree, 0)))
00383     {
00384         print_error_msg_exit("haplo-predict", err->msg);
00385     }
00386 
00387     free_nb_gauss_model_tree(tree);
00388 }
00389 
00391 static void predict_nb_gmm
00392 (
00393     Vector_u32**      labels_out,
00394     Vector_d**        confs_out,
00395     const Matrix_i32* markers
00396 )
00397 {
00398     NB_gmm_model_tree* tree = NULL;
00399     Error* err;
00400 
00401     if (!opts.nb_gmm_fname)
00402         return;
00403 
00404     if ((err = read_nb_gmm_model_tree(&tree, opts.nb_gmm_fname, 
00405                     opts.nb_gmm_dtd_fname, opts.model_dirname)) ||
00406         (err = predict_labels_with_nb_gmm_model_tree(labels_out, 
00407                     confs_out, markers, tree, 0)))
00408     {
00409         print_error_msg_exit("haplo-predict", err->msg);
00410     }
00411 
00412     free_nb_gmm_model_tree(tree);
00413 }
00414 
00416 static void predict_mv_gmm
00417 (
00418     Vector_u32**      labels_out,
00419     Vector_d**        confs_out,
00420     const Matrix_i32* markers
00421 )
00422 {
00423     MV_gmm_model_tree* tree = NULL;
00424     Error* err;
00425 
00426     if (!opts.mv_gmm_fname)
00427         return;
00428 
00429     if ((err = read_mv_gmm_model_tree(&tree, opts.mv_gmm_fname, 
00430                     opts.mv_gmm_dtd_fname, opts.model_dirname)) ||
00431         (err = predict_labels_with_mv_gmm_model_tree(labels_out, 
00432                     confs_out, markers, tree, 0)))
00433     {
00434         print_error_msg_exit("haplo-predict", err->msg);
00435     }
00436 
00437     free_mv_gmm_model_tree(tree);
00438 }
00439 
00441 static void predict_svm
00442 (
00443     Vector_u32**      labels_out,
00444     Vector_d**        confs_out,
00445     const Matrix_i32* markers
00446 )
00447 {
00448 #ifdef HAPLO_ENABLE_SVM
00449     SVM_model_tree* tree = NULL;
00450     Error* err;
00451 
00452     if (!opts.svm_fname)
00453         return;
00454 
00455     if ((err = read_svm_model_tree(&tree, opts.svm_fname, 
00456                     opts.svm_dtd_fname, opts.model_dirname)) ||
00457         (err = predict_labels_with_svm_model_tree(labels_out, confs_out,
00458                     markers, tree)))
00459     {
00460         print_error_msg_exit("haplo-predict", err->msg);
00461     }
00462 
00463     free_svm_model_tree(tree);
00464 #else
00465     return;
00466 #endif
00467 }
00468 
00470 static void predict_j48
00471 (
00472     Vector_u32**      labels_out,
00473     Vector_d**        confs_out,
00474     const Matrix_i32* markers
00475 )
00476 {
00477 #ifdef HAPLO_ENABLE_WEKA
00478     Weka_model_tree* tree = NULL;
00479     Error* err;
00480 
00481     if (!opts.weka_j48_fname)
00482         return;
00483 
00484     if ((err = read_weka_model_tree(&tree, opts.weka_j48_fname, 
00485                     opts.weka_dtd_fname, opts.model_dirname)) ||
00486         (err = predict_labels_with_weka_j48_model_tree(labels_out, 
00487                     confs_out, markers, tree, opts.weka_jar_fname)))
00488     {
00489         print_error_msg_exit("haplo-predict", err->msg);
00490     }
00491 
00492     free_weka_model_tree(tree);
00493 #else
00494     return;
00495 #endif
00496 }
00497 
00499 static void predict_part
00500 (
00501     Vector_u32**      labels_out,
00502     Vector_d**        confs_out,
00503     const Matrix_i32* markers
00504 )
00505 {
00506 #ifdef HAPLO_ENABLE_WEKA
00507     Weka_model_tree* tree = NULL;
00508     Error* err;
00509 
00510     if (!opts.weka_part_fname)
00511         return;
00512 
00513     if ((err = read_weka_model_tree(&tree, opts.weka_part_fname, 
00514                     opts.weka_dtd_fname, opts.model_dirname)) ||
00515         (err = predict_labels_with_weka_part_model_tree(labels_out, 
00516                     confs_out, markers, tree, opts.weka_jar_fname)))
00517     {
00518         print_error_msg_exit("haplo-predict", err->msg);
00519     }
00520 
00521     free_weka_model_tree(tree);
00522 #else
00523     return;
00524 #endif
00525 }
00526 
00528 static void predict_nearest
00529 (
00530     Vector_u32**      labels_out,
00531     Vector_d**        confs_out,
00532     const Matrix_i32* markers
00533 )
00534 {
00535     Nearest_model* model = NULL;
00536     Error*         err;
00537 
00538     if (!opts.nearest_fname)
00539         return;
00540 
00541     if ((err = read_nearest_model(&model, opts.nearest_fname,
00542                     opts.nearest_dtd_fname, opts.model_dirname)) ||
00543         (err = predict_labels_with_nearest_model(labels_out, confs_out,
00544                     markers, model)))
00545     {
00546         print_error_msg_exit("haplo-predict", err->msg);
00547     }
00548 
00549     free_nearest_model(model);
00550 }
00551 
00553 static void write_results_header
00554 (
00555     const Matblock_u8* ids,
00556     const Vector_u32*  labels,
00557     const Vector_u32*  nb_freq_labels,
00558     const Vector_u32*  nb_gauss_labels,
00559     const Vector_u32*  nb_gmm_labels,
00560     const Vector_u32*  mv_gmm_labels,
00561     const Vector_u32*  svm_labels,
00562     const Vector_u32*  j48_labels,
00563     const Vector_u32*  part_labels,
00564     const Vector_u32*  nearest_labels,
00565     FILE*              fp
00566 )
00567 {
00568     uint32_t i;
00569 
00570     if (!opts.header_out || opts.output_format == HAPLO_OUTPUT_XML)
00571         return;
00572 
00573     if (ids)
00574     {
00575         for (i = 0; i < ids->num_rows; i++)
00576         {
00577             switch (opts.output_format)
00578             {
00579                 case HAPLO_OUTPUT_TXT:
00580                     fprintf(fp, "ID %-7d  ", i+1);
00581                     break;
00582                 case HAPLO_OUTPUT_CSV:
00583                     fprintf(fp, "ID %d,", i+1);
00584                     break;
00585                 case HAPLO_OUTPUT_XML:
00586                     break;
00587             }
00588         }
00589     }
00590 
00591     if (labels)
00592     {
00593         switch (opts.output_format)
00594         {
00595             case HAPLO_OUTPUT_TXT:
00596                 fprintf(fp, "%-10s  ", "Label");
00597                 break;
00598             case HAPLO_OUTPUT_CSV:
00599                 fprintf(fp, "%s,", "Label");
00600                 break;
00601             case HAPLO_OUTPUT_XML:
00602                 break;
00603         }
00604     }
00605 
00606     switch (opts.output_format)
00607     {
00608         case HAPLO_OUTPUT_TXT:
00609             fprintf(fp, "%-10s  %-4s", "Ancestor", "Type");
00610             if (nb_freq_labels) 
00611                 fprintf(fp, "  %-10s  %-5s", "NB-Freq", "Conf");
00612             if (nb_gauss_labels) 
00613                 fprintf(fp, "  %-10s  %-5s", "NB-Gauss", "Conf");
00614             if (nb_gmm_labels) 
00615                 fprintf(fp, "  %-10s  %-5s", "NB-Gmm", "Conf");
00616             if (mv_gmm_labels) 
00617                 fprintf(fp, "  %-10s  %-5s", "MV-Gmm", "Conf");
00618             if (svm_labels) 
00619                 fprintf(fp, "  %-10s  %-5s", "SVM", "Conf");
00620             if (j48_labels) 
00621                 fprintf(fp, "  %-10s  %-5s", "J48", "Conf");
00622             if (part_labels) 
00623                 fprintf(fp, "  %-10s  %-5s", "PART", "Conf");
00624             if (nearest_labels) 
00625                 fprintf(fp, "  %-10s  %-5s", "Nearest", "Dist");
00626             break;
00627         case HAPLO_OUTPUT_CSV:
00628             fprintf(fp, "%s,%s", "Ancestor", "Type");
00629             if (nb_freq_labels) 
00630                 fprintf(fp, ",%s,%s", "NB-Freq", "Conf");
00631             if (nb_gauss_labels) 
00632                 fprintf(fp, ",%s,%s", "NB-Gauss", "Conf");
00633             if (nb_gmm_labels) 
00634                 fprintf(fp, ",%s,%s", "NB-Gmm", "Conf");
00635             if (mv_gmm_labels) 
00636                 fprintf(fp, ",%s,%s", "MV-Gmm", "Conf");
00637             if (svm_labels) 
00638                 fprintf(fp, ",%s,%s", "SVM", "Conf");
00639             if (j48_labels) 
00640                 fprintf(fp, ",%s,%s", "J48", "Conf");
00641             if (part_labels) 
00642                 fprintf(fp, ",%s,%s", "PART", "Conf");
00643             if (nearest_labels) 
00644                 fprintf(fp, ",%s,%s", "Nearest", "Dist");
00645             break;
00646         case HAPLO_OUTPUT_XML:
00647             break;
00648     }
00649     fprintf(fp, "\n");
00650 }
00651 
00653 static void find_ancestors
00654 (
00655     Vector_u32**      ancestor_types_out,
00656     Vector_u32**      ancestor_labels_out,
00657     const Vector_u32* labels_1,
00658     const Vector_u32* labels_2,
00659     const Vector_u32* labels_3,
00660     const Vector_u32* labels_4,
00661     const Vector_u32* labels_5,
00662     const Vector_u32* labels_6,
00663     const Vector_u32* labels_7,
00664     const Vector_u32* labels_8
00665 )
00666 {
00667     uint32_t            n, nn, N;
00668     uint32_t            i;
00669     uint32_t            num_labels;
00670     uint32_t            ancestor_label;
00671     Haplo_ancestor_type ancestor_type;
00672     Vector_u32*         labels;
00673     Vector_u32*         labelss;
00674 
00675     N = 0;
00676 
00677     if (labels_1) {N++; num_labels = labels_1->num_elts;}
00678     if (labels_2) {N++; num_labels = labels_2->num_elts;}
00679     if (labels_3) {N++; num_labels = labels_3->num_elts;}
00680     if (labels_4) {N++; num_labels = labels_4->num_elts;}
00681     if (labels_5) {N++; num_labels = labels_5->num_elts;}
00682     if (labels_6) {N++; num_labels = labels_6->num_elts;}
00683     if (labels_7) {N++; num_labels = labels_7->num_elts;}
00684     if (labels_8) {N++; num_labels = labels_8->num_elts;}
00685 
00686     assert(N > 0);
00687 
00688     labels = NULL;
00689     create_vector_u32(&labels, N);
00690 
00691     create_vector_u32(ancestor_types_out, num_labels);
00692     create_vector_u32(ancestor_labels_out, num_labels);
00693 
00694     for (i = 0; i < num_labels; i++)
00695     {
00696         N = 0;
00697 
00698         if (labels_1) labels->elts[ N++ ] = labels_1->elts[ i ];
00699         if (labels_2) labels->elts[ N++ ] = labels_2->elts[ i ];
00700         if (labels_3) labels->elts[ N++ ] = labels_3->elts[ i ];
00701         if (labels_4) labels->elts[ N++ ] = labels_4->elts[ i ];
00702         if (labels_5) labels->elts[ N++ ] = labels_5->elts[ i ];
00703         if (labels_6) labels->elts[ N++ ] = labels_6->elts[ i ];
00704         if (labels_7) labels->elts[ N++ ] = labels_7->elts[ i ];
00705         if (labels_8) labels->elts[ N++ ] = labels_8->elts[ i ];
00706 
00707         ancestor_type = find_ancestor_index_of_set(&ancestor_label, labels);
00708 
00709         if (opts.exclude_one && ancestor_type == HAPLO_ANCESTOR_NONE && N > 3)
00710         {
00711             labelss = NULL;
00712             create_vector_u32(&labelss, N-1);
00713             for (n = 0; ancestor_type == HAPLO_ANCESTOR_NONE && n < N; n++)
00714             {
00715                 for (nn = 0; nn < N-1; nn++)
00716                 {
00717                     if (nn < n)
00718                     {
00719                         labelss->elts[ nn ] = labels->elts[ nn ];
00720                     }
00721                     else
00722                     {
00723                         labelss->elts[ nn ] = labels->elts[ nn+1 ];
00724                     }
00725                 }
00726 
00727                 ancestor_type = find_ancestor_index_of_set(&ancestor_label,
00728                         labelss);
00729             }
00730             free_vector_u32(labelss);
00731         }
00732 
00733         (*ancestor_types_out)->elts[ i ] = ancestor_type;
00734         (*ancestor_labels_out)->elts[ i ] = ancestor_label;
00735     }
00736 }
00737 
00739 static void write_results
00740 (
00741     const Matblock_u8* ids,
00742     const Vector_u32*  labels,
00743     const Matrix_i32*  markers,
00744     const Vector_u32*  nb_freq_labels,
00745     const Vector_d*    nb_freq_confs,
00746     const Vector_u32*  nb_gauss_labels,
00747     const Vector_d*    nb_gauss_confs,
00748     const Vector_u32*  nb_gmm_labels,
00749     const Vector_d*    nb_gmm_confs,
00750     const Vector_u32*  mv_gmm_labels,
00751     const Vector_d*    mv_gmm_confs,
00752     const Vector_u32*  svm_labels,
00753     const Vector_d*    svm_confs,
00754     const Vector_u32*  j48_labels,
00755     const Vector_d*    j48_confs,
00756     const Vector_u32*  part_labels,
00757     const Vector_d*    part_confs,
00758     const Vector_u32*  nearest_labels,
00759     const Vector_d*    nearest_dists,
00760     const Vector_u32*  ancestor_types,
00761     const Vector_u32*  ancestor_labels
00762 )
00763 {
00764     uint32_t    i;
00765     FILE*       fp        = NULL;
00766     xmlDoc*     xml_doc   = NULL;
00767     xmlNode*    xml_root  = NULL;
00768     xmlNode*    xml_node  = NULL;
00769     char*       xml_buf;
00770     uint32_t    xml_len;
00771     Error*      err;
00772 
00773     if ((err = open_output(&fp, &xml_doc, "haplo-predict-out",
00774                     "haplo-predict-out.dtd", output_fname)))
00775     {
00776         print_error_msg("haplo-predict", err->msg);
00777     }
00778 
00779     write_results_header(ids, labels, nb_freq_labels, nb_gauss_labels,
00780             nb_gmm_labels, mv_gmm_labels, svm_labels, j48_labels, part_labels,
00781             nearest_labels, fp);
00782 
00783     if (opts.output_format == HAPLO_OUTPUT_XML)
00784     {
00785         xml_root = xmlDocGetRootElement(xml_doc);
00786     }
00787 
00788     for (i = 0; i < markers->num_rows; i++)
00789     {
00790         if (opts.output_format == HAPLO_OUTPUT_XML)
00791         {
00792             xml_node = XMLNewChild(xml_root, "sample", NULL);
00793             xml_len = 256;
00794             xml_buf = malloc(xml_len*sizeof(char));
00795             snprintf(xml_buf, xml_len, "%d", i+1);
00796             XMLNewProp(xml_node, "number", xml_buf);
00797             free(xml_buf);
00798         }
00799 
00800         write_ids(ids, i, HAPLO_SEP_SUFFIX, fp, xml_node);
00801         write_label(labels, i, HAPLO_SEP_SUFFIX, fp, xml_node);
00802 
00803         write_ancestor_label(ancestor_types, ancestor_labels, i, 
00804                 HAPLO_SEP_NONE, fp, xml_node);
00805 
00806         write_prediction("nb-freq", nb_freq_labels, nb_freq_confs, i, 
00807                 HAPLO_SEP_PREFIX, fp, xml_node);
00808 
00809         write_prediction("nb-gauss", nb_gauss_labels, nb_gauss_confs, i, 
00810                 HAPLO_SEP_PREFIX, fp, xml_node);
00811 
00812         write_prediction("nb-gmm", nb_gmm_labels, nb_gmm_confs, i, 
00813                 HAPLO_SEP_PREFIX, fp, xml_node);
00814 
00815         write_prediction("mv-gmm", mv_gmm_labels, mv_gmm_confs, i, 
00816                 HAPLO_SEP_PREFIX, fp, xml_node);
00817 
00818         write_prediction("svm", svm_labels, svm_confs, i, 
00819                 HAPLO_SEP_PREFIX, fp, xml_node);
00820 
00821         write_prediction("j48", j48_labels, j48_confs, i, 
00822                 HAPLO_SEP_PREFIX, fp, xml_node);
00823 
00824         write_prediction("part", part_labels, part_confs, i, 
00825                 HAPLO_SEP_PREFIX, fp, xml_node);
00826 
00827         write_prediction("nearest", nearest_labels, nearest_dists, i, 
00828                 HAPLO_SEP_PREFIX, fp, xml_node);
00829 
00830         if (opts.output_format != HAPLO_OUTPUT_XML)
00831         {
00832             fprintf(fp, "\n");
00833         }
00834     }
00835 
00836     if ((err = close_output(fp, xml_doc, output_fname)))
00837     {
00838         print_error_msg("haplo-predict", err->msg);
00839     }
00840 }
00841 
00842 
00843 static void predict(Predict_params* p)
00844 #ifdef HAPLO_HAVE_PTHREAD
00845 {
00846     uint32_t n;
00847     typedef void (*f)(Vector_u32**, Vector_d**, const Matrix_i32*);
00848 
00849     f predict_f[NUM_ALGOS] = { 
00850                         predict_nb_freq, predict_nb_gauss, 
00851                         predict_nb_gmm, predict_mv_gmm, 
00852                         predict_svm, predict_j48,
00853                         predict_part, predict_nearest };
00854 
00855     pthread_mutex_lock(&(p->mutex));
00856     n = p->n;
00857 
00858     while (n < NUM_ALGOS)
00859     {
00860         pthread_mutex_unlock(&(p->mutex));
00861 
00862         p->labels[ n ] = NULL;
00863         p->confs[ n ] = NULL;
00864         predict_f[ n ](&(p->labels[ n ]), &(p->confs[ n ]), p->markers);
00865 
00866         pthread_mutex_lock(&(p->mutex));
00867         n = ++(p->n);
00868     }
00869 
00870     pthread_mutex_unlock(&(p->mutex));
00871 }
00872 #else
00873 {
00874     uint32_t n;
00875     typedef void (*f)(Vector_u32**, Vector_d**, const Matrix_i32*);
00876 
00877     f predict_f[NUM_ALGOS] = { 
00878                         predict_nb_freq, predict_nb_gauss, 
00879                         predict_nb_gmm, predict_mv_gmm, 
00880                         predict_svm, predict_j48,
00881                         predict_part, predict_nearest };
00882 
00883     for (n = p->n; n < NUM_ALGOS; n++)
00884     {
00885         p->labels[ n ] = NULL;
00886         p->confs[ n ] = NULL;
00887         predict_f[ n ](&(p->labels[ n ]), &(p->confs[ n ]), p->markers);
00888     }
00889 }
00890 #endif
00891 
00892 
00894 int main(int argc, const char** argv)
00895 {
00896     int         argi;
00897     const char* data_fname = "/dev/stdin";
00898     Error*      err;
00899 
00900     uint32_t n;
00901     typedef void* (*f)(void*);
00902     Predict_params p;
00903 
00904 #ifdef HAPLO_HAVE_PTHREAD
00905     uint32_t t;
00906     pthread_t* threads;
00907 #endif
00908 
00909     Matblock_u8* ids             = NULL;
00910     Vector_u32*  labels          = NULL;
00911     Matrix_i32*  markers         = NULL;
00912     Vector_u32*  ancestor_types  = NULL;
00913     Vector_u32*  ancestor_labels = NULL;
00914 
00915     Vector_u32*  pred_labels[NUM_ALGOS] = {NULL};
00916     Vector_d*    pred_confs[NUM_ALGOS]  = {NULL};
00917 
00918     init_predict_options();
00919 
00920     if ((err = process_options(argc, argv, &argi, NUM_OPTS_NO_ARG, opts_no_arg,
00921                     NUM_OPTS_WITH_ARG, opts_with_arg)) != NULL)
00922     {
00923         print_error_msg_exit("haplo-predict", err->msg);
00924     }
00925 
00926     if ((argc - argi) == 1)
00927     {
00928         data_fname = argv[ argi ];
00929     }
00930 
00931     if (num_models_to_predict() == 0)
00932     {
00933         print_error_msg_exit("haplo-predict", "No models to predict with");
00934     }
00935 
00936     if ((err = read_haplo_groups(opts.labels_fname)))
00937     {
00938         print_error_msg_exit("haplo-predict", err->msg);
00939     }
00940 
00941     if ((err = read_input(&ids, &labels, &markers, data_fname)))
00942     {
00943         print_error_msg_exit("haplo-predict", err->msg);
00944     }
00945 
00946     p.labels  = pred_labels;
00947     p.confs   = pred_confs;
00948     p.markers = markers;
00949     p.n       = 0;
00950 
00951 #ifdef HAPLO_HAVE_PTHREAD
00952     pthread_mutex_init(&(p.mutex), NULL);
00953     assert(threads = malloc(opts.num_threads*sizeof(pthread_t)));
00954     for (t = 0; t < opts.num_threads; t++)
00955     {
00956         pthread_create(&(threads[ t ]), NULL, (f)predict, &p);
00957     }
00958     for (t = 0; t < opts.num_threads; t++)
00959     {
00960         pthread_join(threads[ t ], NULL);
00961     }
00962     free(threads);
00963 #else
00964     predict(&p);
00965 #endif
00966 
00967     find_ancestors(&ancestor_types, &ancestor_labels, pred_labels[0],
00968             pred_labels[1], pred_labels[2], pred_labels[3],
00969             pred_labels[4], pred_labels[5], pred_labels[6],
00970             pred_labels[7]);
00971 
00972     write_results(ids, labels, markers, pred_labels[0], pred_confs[0],
00973             pred_labels[1], pred_confs[1], pred_labels[2], pred_confs[2],
00974             pred_labels[3], pred_confs[3], pred_labels[4], pred_confs[4],
00975             pred_labels[5], pred_confs[5], pred_labels[6], pred_confs[6],
00976             pred_labels[7], pred_confs[7], ancestor_types, ancestor_labels);
00977 
00978     free_matblock_u8(ids);
00979     free_vector_u32(labels);
00980     free_matrix_i32(markers);
00981     free_vector_u32(ancestor_types);
00982     free_vector_u32(ancestor_labels);
00983 
00984     for (n = 0; n < NUM_ALGOS; n++)
00985     {
00986         free_vector_u32(pred_labels[ n ]);
00987         free_vector_d(pred_confs[ n ]);
00988     }
00989 
00990     if (get_num_unhandled_errors() > 0)
00991     {
00992         print_error_msg_exit("haplo-predict", "Unhandled errors");
00993     }
00994 
00995     return EXIT_SUCCESS;
00996 }