Haplo Prediction
predict haplogroups
|
00001 /* 00002 * This work is licensed under a Creative Commons 00003 * Attribution-Noncommercial-Share Alike 3.0 United States License. 00004 * 00005 * http://creativecommons.org/licenses/by-nc-sa/3.0/us/ 00006 * 00007 * You are free: 00008 * 00009 * to Share - to copy, distribute, display, and perform the work 00010 * to Remix - to make derivative works 00011 * 00012 * Under the following conditions: 00013 * 00014 * Attribution. You must attribute the work in the manner specified by the 00015 * author or licensor (but not in any way that suggests that they endorse you 00016 * or your use of the work). 00017 * 00018 * Noncommercial. You may not use this work for commercial purposes. 00019 * 00020 * Share Alike. If you alter, transform, or build upon this work, you may 00021 * distribute the resulting work only under the same or similar license to 00022 * this one. 00023 * 00024 * For any reuse or distribution, you must make clear to others the license 00025 * terms of this work. The best way to do this is by including this header. 00026 * 00027 * Any of the above conditions can be waived if you get permission from the 00028 * copyright holder. 00029 * 00030 * Apart from the remix rights granted under this license, nothing in this 00031 * license impairs or restricts the author's moral rights. 00032 */ 00033 00034 00055 #include <config.h> 00056 00057 #include <stdlib.h> 00058 #include <stdio.h> 00059 #include <string.h> 00060 #include <assert.h> 00061 #include <inttypes.h> 00062 00063 #include <libxml/tree.h> 00064 00065 #ifdef HAPLO_HAVE_PTHREAD 00066 #include <pthread.h> 00067 #endif 00068 00069 #ifdef HAPLO_HAVE_DMALLOC 00070 #include <dmalloc.h> 00071 #endif 00072 00073 #include <jwsc/base/error.h> 00074 #include <jwsc/base/option.h> 00075 #include <jwsc/base/file_io.h> 00076 #include <jwsc/vector/vector.h> 00077 #include <jwsc/matrix/matrix.h> 00078 #include <jwsc/matblock/matblock.h> 00079 00080 #include "haplo_groups.h" 00081 #include "options.h" 00082 #include "output.h" 00083 #include "input.h" 00084 #include "xml.h" 00085 #include "nb_freq.h" 00086 #include "nb_gauss.h" 00087 #include "nb_gmm.h" 00088 #include "mv_gmm.h" 00089 #ifdef HAPLO_ENABLE_SVM 00090 #include "svm_tree.h" 00091 #endif 00092 #ifdef HAPLO_ENABLE_WEKA 00093 #include "weka.h" 00094 #endif 00095 #include "nearest.h" 00096 00097 00098 #ifdef HAPLO_ENABLE_SVM 00099 #define NUM_SVM_OPTS 2 00100 #else 00101 #define NUM_SVM_OPTS 0 00102 #endif 00103 00104 #ifdef HAPLO_ENABLE_WEKA 00105 #define NUM_WEKA_OPTS 4 00106 #else 00107 #define NUM_WEKA_OPTS 0 00108 #endif 00109 00110 #define NUM_OPTS_NO_ARG 1 + NUM_SHARED_OPTS_NO_ARG 00111 #define NUM_OPTS_WITH_ARG 13 + NUM_SVM_OPTS + NUM_WEKA_OPTS + NUM_SHARED_OPTS_WITH_ARG 00112 00113 00114 #define NUM_ALGOS 8 00115 #define LABEL_COL 0 00116 #define OUTPUT_FNAME "/dev/stdout" 00117 00118 00124 typedef struct 00125 { 00126 Vector_u32** labels; 00127 Vector_d** confs; 00128 Matrix_i32* markers; 00129 uint32_t n; 00130 #ifdef HAPLO_HAVE_PTHREAD 00131 pthread_mutex_t mutex; 00132 #endif 00133 } 00134 Predict_params; 00135 00136 00138 Option_no_arg opts_no_arg[NUM_OPTS_NO_ARG]; 00139 00141 Option_with_arg opts_with_arg[NUM_OPTS_WITH_ARG]; 00142 00144 static const char* output_fname = OUTPUT_FNAME; 00145 00146 00148 uint32_t get_num_opts_no_arg() 00149 { 00150 return NUM_OPTS_NO_ARG; 00151 } 00152 00154 uint32_t get_num_opts_with_arg() 00155 { 00156 return NUM_OPTS_WITH_ARG; 00157 } 00158 00160 void print_usage() 00161 { 00162 fprintf(stderr, "usage: haplo-predict OPTIONS [data-fname | <stdin>]\n"); 00163 print_options(stderr, 27, NUM_OPTS_NO_ARG, opts_no_arg, NUM_OPTS_WITH_ARG, 00164 opts_with_arg); 00165 } 00166 00168 Error* process_output_opt(Option_arg arg) 00169 { 00170 if (arg == NULL) 00171 { 00172 return JWSC_EARG("Option 'output' requires an argument"); 00173 } 00174 output_fname = arg; 00175 return NULL; 00176 } 00177 00179 static void init_predict_options(void) 00180 { 00181 uint32_t i; 00182 00183 char s_name; 00184 const char* l_name; 00185 const char* desc; 00186 00187 Error* (*fnoarg)(); 00188 Error* (*farg)(const char*); 00189 00190 init_options(opts_no_arg, opts_with_arg); 00191 00192 opts.label_col = LABEL_COL; 00193 00194 i = NUM_SHARED_OPTS_NO_ARG; 00195 l_name = "exclude-one"; 00196 s_name = 0; 00197 desc = "When performing the tandem prediction decision, exclude at most one prediction from the set of classification algorithms. There must be three or more algorithms in play for this to take effect."; 00198 fnoarg = process_exclude_one_opt; 00199 init_option_no_arg(&(opts_no_arg[i++]), l_name, s_name, desc, fnoarg); 00200 assert(i == NUM_OPTS_NO_ARG); 00201 00202 i = NUM_SHARED_OPTS_WITH_ARG; 00203 l_name = "output"; 00204 s_name = 0; 00205 desc = "File to output the predictions to. The default is stdout."; 00206 farg = process_output_opt; 00207 init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg); 00208 00209 l_name = "model-dir"; 00210 s_name = 0; 00211 desc = "Directory containing the trained models."; 00212 farg = process_model_dir_opt; 00213 init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg); 00214 00215 l_name = "nb-freq"; 00216 s_name = 0; 00217 desc = "Naive Bayes non-parametric frequency model tree information."; 00218 farg = process_nb_freq_opt; 00219 init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg); 00220 00221 l_name = "nb-freq-dtd"; 00222 s_name = 0; 00223 desc = "Validate the naive Bayes non-parametric frequency model tree information XML file with this DTD."; 00224 farg = process_nb_freq_dtd_opt; 00225 init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg); 00226 00227 l_name = "nb-gauss"; 00228 s_name = 0; 00229 desc = "Naive Bayes Gaussian model tree information."; 00230 farg = process_nb_gauss_opt; 00231 init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg); 00232 00233 l_name = "nb-gauss-dtd"; 00234 s_name = 0; 00235 desc = "Validate the naive Bayes Gaussian model tree information XML file with this DTD."; 00236 farg = process_nb_gauss_dtd_opt; 00237 init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg); 00238 00239 l_name = "nb-gmm"; 00240 s_name = 0; 00241 desc = "Naive Bayes Gaussian mixture model tree information."; 00242 farg = process_nb_gmm_opt; 00243 init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg); 00244 00245 l_name = "nb-gmm-dtd"; 00246 s_name = 0; 00247 desc = "Validate the naive Bayes Gaussian mixture model tree information XML file with this DTD."; 00248 farg = process_nb_gmm_dtd_opt; 00249 init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg); 00250 00251 l_name = "mv-gmm"; 00252 s_name = 0; 00253 desc = "Multivariate Gaussian mixture model tree information."; 00254 farg = process_mv_gmm_opt; 00255 init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg); 00256 00257 l_name = "mv-gmm-dtd"; 00258 s_name = 0; 00259 desc = "Validate the multivariate Gaussian mixture model tree information XML file with this DTD."; 00260 farg = process_mv_gmm_dtd_opt; 00261 init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg); 00262 00263 #ifdef HAPLO_ENABLE_SVM 00264 l_name = "svm"; 00265 s_name = 0; 00266 desc = "SVM model tree information."; 00267 farg = process_svm_opt; 00268 init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg); 00269 00270 l_name = "svm-dtd"; 00271 s_name = 0; 00272 desc = "Validate the SVM model tree information XML file with this DTD."; 00273 farg = process_svm_dtd_opt; 00274 init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg); 00275 #endif 00276 00277 #ifdef HAPLO_ENABLE_WEKA 00278 l_name = "weka-j48"; 00279 s_name = 0; 00280 desc = "Weka J48 model tree information."; 00281 farg = process_weka_j48_opt; 00282 init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg); 00283 00284 l_name = "weka-part"; 00285 s_name = 0; 00286 desc = "Weka PART model tree information."; 00287 farg = process_weka_part_opt; 00288 init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg); 00289 00290 l_name = "weka-jar"; 00291 s_name = 0; 00292 desc = "Weka java archive file. Required for using the Weka algorithms."; 00293 farg = process_weka_jar_opt; 00294 init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg); 00295 00296 l_name = "weka-dtd"; 00297 s_name = 0; 00298 desc = "Validate the Weka model tree information XML files with this DTD."; 00299 farg = process_weka_dtd_opt; 00300 init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg); 00301 #endif 00302 00303 l_name = "nearest"; 00304 s_name = 0; 00305 desc = "Nearest neighbor model information."; 00306 farg = process_nearest_opt; 00307 init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg); 00308 00309 l_name = "nearest-dtd"; 00310 s_name = 0; 00311 desc = "Validate the nearest neighbor model information XML file with this DTD."; 00312 farg = process_nearest_dtd_opt; 00313 init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg); 00314 00315 l_name = "nearest-max-d"; 00316 s_name = 0; 00317 desc = "Maximum distance allowed for a nearest neighbor classification."; 00318 farg = process_nearest_max_d_opt; 00319 init_option_with_arg(&(opts_with_arg[i++]), l_name, s_name, desc, farg); 00320 assert(i == NUM_OPTS_WITH_ARG); 00321 } 00322 00324 static uint8_t num_models_to_predict() 00325 { 00326 return (opts.nb_freq_fname != 0) + 00327 (opts.nb_gauss_fname != 0) + 00328 (opts.nb_gmm_fname != 0) + 00329 (opts.mv_gmm_fname != 0) + 00330 #ifdef HAPLO_ENABLE_SVM 00331 (opts.svm_fname != 0) + 00332 #endif 00333 #ifdef HAPLO_ENABLE_WEKA 00334 (opts.weka_j48_fname != 0) + 00335 (opts.weka_part_fname != 0) + 00336 #endif 00337 (opts.nearest_fname != 0); 00338 } 00339 00341 static void predict_nb_freq 00342 ( 00343 Vector_u32** labels_out, 00344 Vector_d** confs_out, 00345 const Matrix_i32* markers 00346 ) 00347 { 00348 NB_freq_model_tree* tree = NULL; 00349 Error* err; 00350 00351 if (!opts.nb_freq_fname) 00352 return; 00353 00354 if ((err = read_nb_freq_model_tree(&tree, opts.nb_freq_fname, 00355 opts.nb_freq_dtd_fname, opts.model_dirname)) || 00356 (err = predict_labels_with_nb_freq_model_tree(labels_out, 00357 confs_out, markers, tree, 0))) 00358 { 00359 print_error_msg_exit("haplo-predict", err->msg); 00360 } 00361 00362 free_nb_freq_model_tree(tree); 00363 } 00364 00366 static void predict_nb_gauss 00367 ( 00368 Vector_u32** labels_out, 00369 Vector_d** confs_out, 00370 const Matrix_i32* markers 00371 ) 00372 { 00373 NB_gauss_model_tree* tree = NULL; 00374 Error* err; 00375 00376 if (!opts.nb_gauss_fname) 00377 return; 00378 00379 if ((err = read_nb_gauss_model_tree(&tree, opts.nb_gauss_fname, 00380 opts.nb_gauss_dtd_fname, opts.model_dirname)) || 00381 (err = predict_labels_with_nb_gauss_model_tree(labels_out, 00382 confs_out, markers, tree, 0))) 00383 { 00384 print_error_msg_exit("haplo-predict", err->msg); 00385 } 00386 00387 free_nb_gauss_model_tree(tree); 00388 } 00389 00391 static void predict_nb_gmm 00392 ( 00393 Vector_u32** labels_out, 00394 Vector_d** confs_out, 00395 const Matrix_i32* markers 00396 ) 00397 { 00398 NB_gmm_model_tree* tree = NULL; 00399 Error* err; 00400 00401 if (!opts.nb_gmm_fname) 00402 return; 00403 00404 if ((err = read_nb_gmm_model_tree(&tree, opts.nb_gmm_fname, 00405 opts.nb_gmm_dtd_fname, opts.model_dirname)) || 00406 (err = predict_labels_with_nb_gmm_model_tree(labels_out, 00407 confs_out, markers, tree, 0))) 00408 { 00409 print_error_msg_exit("haplo-predict", err->msg); 00410 } 00411 00412 free_nb_gmm_model_tree(tree); 00413 } 00414 00416 static void predict_mv_gmm 00417 ( 00418 Vector_u32** labels_out, 00419 Vector_d** confs_out, 00420 const Matrix_i32* markers 00421 ) 00422 { 00423 MV_gmm_model_tree* tree = NULL; 00424 Error* err; 00425 00426 if (!opts.mv_gmm_fname) 00427 return; 00428 00429 if ((err = read_mv_gmm_model_tree(&tree, opts.mv_gmm_fname, 00430 opts.mv_gmm_dtd_fname, opts.model_dirname)) || 00431 (err = predict_labels_with_mv_gmm_model_tree(labels_out, 00432 confs_out, markers, tree, 0))) 00433 { 00434 print_error_msg_exit("haplo-predict", err->msg); 00435 } 00436 00437 free_mv_gmm_model_tree(tree); 00438 } 00439 00441 static void predict_svm 00442 ( 00443 Vector_u32** labels_out, 00444 Vector_d** confs_out, 00445 const Matrix_i32* markers 00446 ) 00447 { 00448 #ifdef HAPLO_ENABLE_SVM 00449 SVM_model_tree* tree = NULL; 00450 Error* err; 00451 00452 if (!opts.svm_fname) 00453 return; 00454 00455 if ((err = read_svm_model_tree(&tree, opts.svm_fname, 00456 opts.svm_dtd_fname, opts.model_dirname)) || 00457 (err = predict_labels_with_svm_model_tree(labels_out, confs_out, 00458 markers, tree))) 00459 { 00460 print_error_msg_exit("haplo-predict", err->msg); 00461 } 00462 00463 free_svm_model_tree(tree); 00464 #else 00465 return; 00466 #endif 00467 } 00468 00470 static void predict_j48 00471 ( 00472 Vector_u32** labels_out, 00473 Vector_d** confs_out, 00474 const Matrix_i32* markers 00475 ) 00476 { 00477 #ifdef HAPLO_ENABLE_WEKA 00478 Weka_model_tree* tree = NULL; 00479 Error* err; 00480 00481 if (!opts.weka_j48_fname) 00482 return; 00483 00484 if ((err = read_weka_model_tree(&tree, opts.weka_j48_fname, 00485 opts.weka_dtd_fname, opts.model_dirname)) || 00486 (err = predict_labels_with_weka_j48_model_tree(labels_out, 00487 confs_out, markers, tree, opts.weka_jar_fname))) 00488 { 00489 print_error_msg_exit("haplo-predict", err->msg); 00490 } 00491 00492 free_weka_model_tree(tree); 00493 #else 00494 return; 00495 #endif 00496 } 00497 00499 static void predict_part 00500 ( 00501 Vector_u32** labels_out, 00502 Vector_d** confs_out, 00503 const Matrix_i32* markers 00504 ) 00505 { 00506 #ifdef HAPLO_ENABLE_WEKA 00507 Weka_model_tree* tree = NULL; 00508 Error* err; 00509 00510 if (!opts.weka_part_fname) 00511 return; 00512 00513 if ((err = read_weka_model_tree(&tree, opts.weka_part_fname, 00514 opts.weka_dtd_fname, opts.model_dirname)) || 00515 (err = predict_labels_with_weka_part_model_tree(labels_out, 00516 confs_out, markers, tree, opts.weka_jar_fname))) 00517 { 00518 print_error_msg_exit("haplo-predict", err->msg); 00519 } 00520 00521 free_weka_model_tree(tree); 00522 #else 00523 return; 00524 #endif 00525 } 00526 00528 static void predict_nearest 00529 ( 00530 Vector_u32** labels_out, 00531 Vector_d** confs_out, 00532 const Matrix_i32* markers 00533 ) 00534 { 00535 Nearest_model* model = NULL; 00536 Error* err; 00537 00538 if (!opts.nearest_fname) 00539 return; 00540 00541 if ((err = read_nearest_model(&model, opts.nearest_fname, 00542 opts.nearest_dtd_fname, opts.model_dirname)) || 00543 (err = predict_labels_with_nearest_model(labels_out, confs_out, 00544 markers, model))) 00545 { 00546 print_error_msg_exit("haplo-predict", err->msg); 00547 } 00548 00549 free_nearest_model(model); 00550 } 00551 00553 static void write_results_header 00554 ( 00555 const Matblock_u8* ids, 00556 const Vector_u32* labels, 00557 const Vector_u32* nb_freq_labels, 00558 const Vector_u32* nb_gauss_labels, 00559 const Vector_u32* nb_gmm_labels, 00560 const Vector_u32* mv_gmm_labels, 00561 const Vector_u32* svm_labels, 00562 const Vector_u32* j48_labels, 00563 const Vector_u32* part_labels, 00564 const Vector_u32* nearest_labels, 00565 FILE* fp 00566 ) 00567 { 00568 uint32_t i; 00569 00570 if (!opts.header_out || opts.output_format == HAPLO_OUTPUT_XML) 00571 return; 00572 00573 if (ids) 00574 { 00575 for (i = 0; i < ids->num_rows; i++) 00576 { 00577 switch (opts.output_format) 00578 { 00579 case HAPLO_OUTPUT_TXT: 00580 fprintf(fp, "ID %-7d ", i+1); 00581 break; 00582 case HAPLO_OUTPUT_CSV: 00583 fprintf(fp, "ID %d,", i+1); 00584 break; 00585 case HAPLO_OUTPUT_XML: 00586 break; 00587 } 00588 } 00589 } 00590 00591 if (labels) 00592 { 00593 switch (opts.output_format) 00594 { 00595 case HAPLO_OUTPUT_TXT: 00596 fprintf(fp, "%-10s ", "Label"); 00597 break; 00598 case HAPLO_OUTPUT_CSV: 00599 fprintf(fp, "%s,", "Label"); 00600 break; 00601 case HAPLO_OUTPUT_XML: 00602 break; 00603 } 00604 } 00605 00606 switch (opts.output_format) 00607 { 00608 case HAPLO_OUTPUT_TXT: 00609 fprintf(fp, "%-10s %-4s", "Ancestor", "Type"); 00610 if (nb_freq_labels) 00611 fprintf(fp, " %-10s %-5s", "NB-Freq", "Conf"); 00612 if (nb_gauss_labels) 00613 fprintf(fp, " %-10s %-5s", "NB-Gauss", "Conf"); 00614 if (nb_gmm_labels) 00615 fprintf(fp, " %-10s %-5s", "NB-Gmm", "Conf"); 00616 if (mv_gmm_labels) 00617 fprintf(fp, " %-10s %-5s", "MV-Gmm", "Conf"); 00618 if (svm_labels) 00619 fprintf(fp, " %-10s %-5s", "SVM", "Conf"); 00620 if (j48_labels) 00621 fprintf(fp, " %-10s %-5s", "J48", "Conf"); 00622 if (part_labels) 00623 fprintf(fp, " %-10s %-5s", "PART", "Conf"); 00624 if (nearest_labels) 00625 fprintf(fp, " %-10s %-5s", "Nearest", "Dist"); 00626 break; 00627 case HAPLO_OUTPUT_CSV: 00628 fprintf(fp, "%s,%s", "Ancestor", "Type"); 00629 if (nb_freq_labels) 00630 fprintf(fp, ",%s,%s", "NB-Freq", "Conf"); 00631 if (nb_gauss_labels) 00632 fprintf(fp, ",%s,%s", "NB-Gauss", "Conf"); 00633 if (nb_gmm_labels) 00634 fprintf(fp, ",%s,%s", "NB-Gmm", "Conf"); 00635 if (mv_gmm_labels) 00636 fprintf(fp, ",%s,%s", "MV-Gmm", "Conf"); 00637 if (svm_labels) 00638 fprintf(fp, ",%s,%s", "SVM", "Conf"); 00639 if (j48_labels) 00640 fprintf(fp, ",%s,%s", "J48", "Conf"); 00641 if (part_labels) 00642 fprintf(fp, ",%s,%s", "PART", "Conf"); 00643 if (nearest_labels) 00644 fprintf(fp, ",%s,%s", "Nearest", "Dist"); 00645 break; 00646 case HAPLO_OUTPUT_XML: 00647 break; 00648 } 00649 fprintf(fp, "\n"); 00650 } 00651 00653 static void find_ancestors 00654 ( 00655 Vector_u32** ancestor_types_out, 00656 Vector_u32** ancestor_labels_out, 00657 const Vector_u32* labels_1, 00658 const Vector_u32* labels_2, 00659 const Vector_u32* labels_3, 00660 const Vector_u32* labels_4, 00661 const Vector_u32* labels_5, 00662 const Vector_u32* labels_6, 00663 const Vector_u32* labels_7, 00664 const Vector_u32* labels_8 00665 ) 00666 { 00667 uint32_t n, nn, N; 00668 uint32_t i; 00669 uint32_t num_labels; 00670 uint32_t ancestor_label; 00671 Haplo_ancestor_type ancestor_type; 00672 Vector_u32* labels; 00673 Vector_u32* labelss; 00674 00675 N = 0; 00676 00677 if (labels_1) {N++; num_labels = labels_1->num_elts;} 00678 if (labels_2) {N++; num_labels = labels_2->num_elts;} 00679 if (labels_3) {N++; num_labels = labels_3->num_elts;} 00680 if (labels_4) {N++; num_labels = labels_4->num_elts;} 00681 if (labels_5) {N++; num_labels = labels_5->num_elts;} 00682 if (labels_6) {N++; num_labels = labels_6->num_elts;} 00683 if (labels_7) {N++; num_labels = labels_7->num_elts;} 00684 if (labels_8) {N++; num_labels = labels_8->num_elts;} 00685 00686 assert(N > 0); 00687 00688 labels = NULL; 00689 create_vector_u32(&labels, N); 00690 00691 create_vector_u32(ancestor_types_out, num_labels); 00692 create_vector_u32(ancestor_labels_out, num_labels); 00693 00694 for (i = 0; i < num_labels; i++) 00695 { 00696 N = 0; 00697 00698 if (labels_1) labels->elts[ N++ ] = labels_1->elts[ i ]; 00699 if (labels_2) labels->elts[ N++ ] = labels_2->elts[ i ]; 00700 if (labels_3) labels->elts[ N++ ] = labels_3->elts[ i ]; 00701 if (labels_4) labels->elts[ N++ ] = labels_4->elts[ i ]; 00702 if (labels_5) labels->elts[ N++ ] = labels_5->elts[ i ]; 00703 if (labels_6) labels->elts[ N++ ] = labels_6->elts[ i ]; 00704 if (labels_7) labels->elts[ N++ ] = labels_7->elts[ i ]; 00705 if (labels_8) labels->elts[ N++ ] = labels_8->elts[ i ]; 00706 00707 ancestor_type = find_ancestor_index_of_set(&ancestor_label, labels); 00708 00709 if (opts.exclude_one && ancestor_type == HAPLO_ANCESTOR_NONE && N > 3) 00710 { 00711 labelss = NULL; 00712 create_vector_u32(&labelss, N-1); 00713 for (n = 0; ancestor_type == HAPLO_ANCESTOR_NONE && n < N; n++) 00714 { 00715 for (nn = 0; nn < N-1; nn++) 00716 { 00717 if (nn < n) 00718 { 00719 labelss->elts[ nn ] = labels->elts[ nn ]; 00720 } 00721 else 00722 { 00723 labelss->elts[ nn ] = labels->elts[ nn+1 ]; 00724 } 00725 } 00726 00727 ancestor_type = find_ancestor_index_of_set(&ancestor_label, 00728 labelss); 00729 } 00730 free_vector_u32(labelss); 00731 } 00732 00733 (*ancestor_types_out)->elts[ i ] = ancestor_type; 00734 (*ancestor_labels_out)->elts[ i ] = ancestor_label; 00735 } 00736 } 00737 00739 static void write_results 00740 ( 00741 const Matblock_u8* ids, 00742 const Vector_u32* labels, 00743 const Matrix_i32* markers, 00744 const Vector_u32* nb_freq_labels, 00745 const Vector_d* nb_freq_confs, 00746 const Vector_u32* nb_gauss_labels, 00747 const Vector_d* nb_gauss_confs, 00748 const Vector_u32* nb_gmm_labels, 00749 const Vector_d* nb_gmm_confs, 00750 const Vector_u32* mv_gmm_labels, 00751 const Vector_d* mv_gmm_confs, 00752 const Vector_u32* svm_labels, 00753 const Vector_d* svm_confs, 00754 const Vector_u32* j48_labels, 00755 const Vector_d* j48_confs, 00756 const Vector_u32* part_labels, 00757 const Vector_d* part_confs, 00758 const Vector_u32* nearest_labels, 00759 const Vector_d* nearest_dists, 00760 const Vector_u32* ancestor_types, 00761 const Vector_u32* ancestor_labels 00762 ) 00763 { 00764 uint32_t i; 00765 FILE* fp = NULL; 00766 xmlDoc* xml_doc = NULL; 00767 xmlNode* xml_root = NULL; 00768 xmlNode* xml_node = NULL; 00769 char* xml_buf; 00770 uint32_t xml_len; 00771 Error* err; 00772 00773 if ((err = open_output(&fp, &xml_doc, "haplo-predict-out", 00774 "haplo-predict-out.dtd", output_fname))) 00775 { 00776 print_error_msg("haplo-predict", err->msg); 00777 } 00778 00779 write_results_header(ids, labels, nb_freq_labels, nb_gauss_labels, 00780 nb_gmm_labels, mv_gmm_labels, svm_labels, j48_labels, part_labels, 00781 nearest_labels, fp); 00782 00783 if (opts.output_format == HAPLO_OUTPUT_XML) 00784 { 00785 xml_root = xmlDocGetRootElement(xml_doc); 00786 } 00787 00788 for (i = 0; i < markers->num_rows; i++) 00789 { 00790 if (opts.output_format == HAPLO_OUTPUT_XML) 00791 { 00792 xml_node = XMLNewChild(xml_root, "sample", NULL); 00793 xml_len = 256; 00794 xml_buf = malloc(xml_len*sizeof(char)); 00795 snprintf(xml_buf, xml_len, "%d", i+1); 00796 XMLNewProp(xml_node, "number", xml_buf); 00797 free(xml_buf); 00798 } 00799 00800 write_ids(ids, i, HAPLO_SEP_SUFFIX, fp, xml_node); 00801 write_label(labels, i, HAPLO_SEP_SUFFIX, fp, xml_node); 00802 00803 write_ancestor_label(ancestor_types, ancestor_labels, i, 00804 HAPLO_SEP_NONE, fp, xml_node); 00805 00806 write_prediction("nb-freq", nb_freq_labels, nb_freq_confs, i, 00807 HAPLO_SEP_PREFIX, fp, xml_node); 00808 00809 write_prediction("nb-gauss", nb_gauss_labels, nb_gauss_confs, i, 00810 HAPLO_SEP_PREFIX, fp, xml_node); 00811 00812 write_prediction("nb-gmm", nb_gmm_labels, nb_gmm_confs, i, 00813 HAPLO_SEP_PREFIX, fp, xml_node); 00814 00815 write_prediction("mv-gmm", mv_gmm_labels, mv_gmm_confs, i, 00816 HAPLO_SEP_PREFIX, fp, xml_node); 00817 00818 write_prediction("svm", svm_labels, svm_confs, i, 00819 HAPLO_SEP_PREFIX, fp, xml_node); 00820 00821 write_prediction("j48", j48_labels, j48_confs, i, 00822 HAPLO_SEP_PREFIX, fp, xml_node); 00823 00824 write_prediction("part", part_labels, part_confs, i, 00825 HAPLO_SEP_PREFIX, fp, xml_node); 00826 00827 write_prediction("nearest", nearest_labels, nearest_dists, i, 00828 HAPLO_SEP_PREFIX, fp, xml_node); 00829 00830 if (opts.output_format != HAPLO_OUTPUT_XML) 00831 { 00832 fprintf(fp, "\n"); 00833 } 00834 } 00835 00836 if ((err = close_output(fp, xml_doc, output_fname))) 00837 { 00838 print_error_msg("haplo-predict", err->msg); 00839 } 00840 } 00841 00842 00843 static void predict(Predict_params* p) 00844 #ifdef HAPLO_HAVE_PTHREAD 00845 { 00846 uint32_t n; 00847 typedef void (*f)(Vector_u32**, Vector_d**, const Matrix_i32*); 00848 00849 f predict_f[NUM_ALGOS] = { 00850 predict_nb_freq, predict_nb_gauss, 00851 predict_nb_gmm, predict_mv_gmm, 00852 predict_svm, predict_j48, 00853 predict_part, predict_nearest }; 00854 00855 pthread_mutex_lock(&(p->mutex)); 00856 n = p->n; 00857 00858 while (n < NUM_ALGOS) 00859 { 00860 pthread_mutex_unlock(&(p->mutex)); 00861 00862 p->labels[ n ] = NULL; 00863 p->confs[ n ] = NULL; 00864 predict_f[ n ](&(p->labels[ n ]), &(p->confs[ n ]), p->markers); 00865 00866 pthread_mutex_lock(&(p->mutex)); 00867 n = ++(p->n); 00868 } 00869 00870 pthread_mutex_unlock(&(p->mutex)); 00871 } 00872 #else 00873 { 00874 uint32_t n; 00875 typedef void (*f)(Vector_u32**, Vector_d**, const Matrix_i32*); 00876 00877 f predict_f[NUM_ALGOS] = { 00878 predict_nb_freq, predict_nb_gauss, 00879 predict_nb_gmm, predict_mv_gmm, 00880 predict_svm, predict_j48, 00881 predict_part, predict_nearest }; 00882 00883 for (n = p->n; n < NUM_ALGOS; n++) 00884 { 00885 p->labels[ n ] = NULL; 00886 p->confs[ n ] = NULL; 00887 predict_f[ n ](&(p->labels[ n ]), &(p->confs[ n ]), p->markers); 00888 } 00889 } 00890 #endif 00891 00892 00894 int main(int argc, const char** argv) 00895 { 00896 int argi; 00897 const char* data_fname = "/dev/stdin"; 00898 Error* err; 00899 00900 uint32_t n; 00901 typedef void* (*f)(void*); 00902 Predict_params p; 00903 00904 #ifdef HAPLO_HAVE_PTHREAD 00905 uint32_t t; 00906 pthread_t* threads; 00907 #endif 00908 00909 Matblock_u8* ids = NULL; 00910 Vector_u32* labels = NULL; 00911 Matrix_i32* markers = NULL; 00912 Vector_u32* ancestor_types = NULL; 00913 Vector_u32* ancestor_labels = NULL; 00914 00915 Vector_u32* pred_labels[NUM_ALGOS] = {NULL}; 00916 Vector_d* pred_confs[NUM_ALGOS] = {NULL}; 00917 00918 init_predict_options(); 00919 00920 if ((err = process_options(argc, argv, &argi, NUM_OPTS_NO_ARG, opts_no_arg, 00921 NUM_OPTS_WITH_ARG, opts_with_arg)) != NULL) 00922 { 00923 print_error_msg_exit("haplo-predict", err->msg); 00924 } 00925 00926 if ((argc - argi) == 1) 00927 { 00928 data_fname = argv[ argi ]; 00929 } 00930 00931 if (num_models_to_predict() == 0) 00932 { 00933 print_error_msg_exit("haplo-predict", "No models to predict with"); 00934 } 00935 00936 if ((err = read_haplo_groups(opts.labels_fname))) 00937 { 00938 print_error_msg_exit("haplo-predict", err->msg); 00939 } 00940 00941 if ((err = read_input(&ids, &labels, &markers, data_fname))) 00942 { 00943 print_error_msg_exit("haplo-predict", err->msg); 00944 } 00945 00946 p.labels = pred_labels; 00947 p.confs = pred_confs; 00948 p.markers = markers; 00949 p.n = 0; 00950 00951 #ifdef HAPLO_HAVE_PTHREAD 00952 pthread_mutex_init(&(p.mutex), NULL); 00953 assert(threads = malloc(opts.num_threads*sizeof(pthread_t))); 00954 for (t = 0; t < opts.num_threads; t++) 00955 { 00956 pthread_create(&(threads[ t ]), NULL, (f)predict, &p); 00957 } 00958 for (t = 0; t < opts.num_threads; t++) 00959 { 00960 pthread_join(threads[ t ], NULL); 00961 } 00962 free(threads); 00963 #else 00964 predict(&p); 00965 #endif 00966 00967 find_ancestors(&ancestor_types, &ancestor_labels, pred_labels[0], 00968 pred_labels[1], pred_labels[2], pred_labels[3], 00969 pred_labels[4], pred_labels[5], pred_labels[6], 00970 pred_labels[7]); 00971 00972 write_results(ids, labels, markers, pred_labels[0], pred_confs[0], 00973 pred_labels[1], pred_confs[1], pred_labels[2], pred_confs[2], 00974 pred_labels[3], pred_confs[3], pred_labels[4], pred_confs[4], 00975 pred_labels[5], pred_confs[5], pred_labels[6], pred_confs[6], 00976 pred_labels[7], pred_confs[7], ancestor_types, ancestor_labels); 00977 00978 free_matblock_u8(ids); 00979 free_vector_u32(labels); 00980 free_matrix_i32(markers); 00981 free_vector_u32(ancestor_types); 00982 free_vector_u32(ancestor_labels); 00983 00984 for (n = 0; n < NUM_ALGOS; n++) 00985 { 00986 free_vector_u32(pred_labels[ n ]); 00987 free_vector_d(pred_confs[ n ]); 00988 } 00989 00990 if (get_num_unhandled_errors() > 0) 00991 { 00992 print_error_msg_exit("haplo-predict", "Unhandled errors"); 00993 } 00994 00995 return EXIT_SUCCESS; 00996 }