Haplo Prediction
predict haplogroups
|
00001 /* 00002 * This work is licensed under a Creative Commons 00003 * Attribution-Noncommercial-Share Alike 3.0 United States License. 00004 * 00005 * http://creativecommons.org/licenses/by-nc-sa/3.0/us/ 00006 * 00007 * You are free: 00008 * 00009 * to Share - to copy, distribute, display, and perform the work 00010 * to Remix - to make derivative works 00011 * 00012 * Under the following conditions: 00013 * 00014 * Attribution. You must attribute the work in the manner specified by the 00015 * author or licensor (but not in any way that suggests that they endorse you 00016 * or your use of the work). 00017 * 00018 * Noncommercial. You may not use this work for commercial purposes. 00019 * 00020 * Share Alike. If you alter, transform, or build upon this work, you may 00021 * distribute the resulting work only under the same or similar license to 00022 * this one. 00023 * 00024 * For any reuse or distribution, you must make clear to others the license 00025 * terms of this work. The best way to do this is by including this header. 00026 * 00027 * Any of the above conditions can be waived if you get permission from the 00028 * copyright holder. 00029 * 00030 * Apart from the remix rights granted under this license, nothing in this 00031 * license impairs or restricts the author's moral rights. 00032 */ 00033 00034 00046 #include <config.h> 00047 00048 #include <stdlib.h> 00049 #include <stdio.h> 00050 #include <inttypes.h> 00051 #include <assert.h> 00052 #include <string.h> 00053 #include <errno.h> 00054 #include <math.h> 00055 00056 #include <libxml/tree.h> 00057 #include <libxml/parser.h> 00058 #include <libxml/valid.h> 00059 00060 #if defined HAPLO_HAVE_LIBSVM_H 00061 #include <libsvm.h> 00062 #elif defined HAPLO_HAVE_SVM_H 00063 #include <svm.h> 00064 #endif 00065 00066 #ifdef HAPLO_HAVE_DMALLOC 00067 #include <dmalloc.h> 00068 #endif 00069 00070 #include <jwsc/base/error.h> 00071 #include <jwsc/base/limits.h> 00072 #include <jwsc/base/file_io.h> 00073 #include <jwsc/vector/vector.h> 00074 #include <jwsc/vector/vector_io.h> 00075 #include <jwsc/vector/vector_math.h> 00076 #include <jwsc/matrix/matrix.h> 00077 #include <jwsc/matrix/matrix_io.h> 00078 #include <jwsc/matblock/matblock.h> 00079 #include <jwsc/matblock/matblock_io.h> 00080 #include <jwsc/stat/gmm.h> 00081 00082 #include "xml.h" 00083 #include "haplo_groups.h" 00084 #include "svm_tree.h" 00085 00086 00097 void train_svm_model 00098 ( 00099 SVM_model** model_out, 00100 const Vector_u32* labels, 00101 const Matrix_i32* markers, 00102 double cost, 00103 double gamma 00104 ) 00105 { 00106 uint32_t s, num_samples; 00107 uint32_t m, num_markers; 00108 int32_t marker_val; 00109 00110 SVM_model* model; 00111 struct svm_parameter* param; 00112 00113 if (*model_out != NULL) 00114 { 00115 free_svm_model(*model_out); 00116 } 00117 00118 assert(*model_out = malloc(sizeof(SVM_model))); 00119 model = *model_out; 00120 00121 num_samples = markers->num_rows; 00122 num_markers = markers->num_cols; 00123 00124 assert(model->prob = malloc(sizeof(struct svm_problem))); 00125 assert(model->prob->y = malloc(num_samples*sizeof(double))); 00126 assert(model->prob->x = malloc(num_samples*sizeof(struct svm_node*))); 00127 00128 model->prob->l = num_samples; 00129 for (s = 0; s < num_samples; s++) 00130 { 00131 model->prob->y[ s ] = labels->elts[ s ]; 00132 model->prob->x[ s ] = malloc((num_markers+1)*sizeof(struct svm_node)); 00133 assert(model->prob->x[ s ]); 00134 00135 for (m = 0; m < num_markers; m++) 00136 { 00137 model->prob->x[ s ][ m ].index = m+1; 00138 marker_val = markers->elts[ s ][ m ]; 00139 model->prob->x[ s ][ m ].value = marker_val; 00140 } 00141 model->prob->x[ s ][ m ].index = -1; 00142 } 00143 00144 assert(param = malloc(sizeof(struct svm_parameter))); 00145 00146 param->svm_type = C_SVC; 00147 param->kernel_type = RBF; 00148 param->gamma = gamma; 00149 param->C = cost; 00150 param->nr_weight = 0; 00151 param->probability = 1; 00152 param->shrinking = 1; /* svm-train default */ 00153 param->eps = 0.001; /* svm-train default */ 00154 param->cache_size = 40; /* svm-train default */ 00155 00156 assert(model->svm = svm_train(model->prob, param)); 00157 00158 free(param); 00159 } 00160 00161 00171 Error* predict_label_with_svm_model 00172 ( 00173 uint32_t* label_out, 00174 double** confidence_out, 00175 const struct svm_node* markers, 00176 const SVM_model* model 00177 ) 00178 { 00179 assert(*confidence_out); 00180 *label_out = svm_predict_probability(model->svm, markers, *confidence_out); 00181 00182 return NULL; 00183 } 00184 00185 00187 static void get_predicted_labels_confidence 00188 ( 00189 Vector_d** confidence_out, 00190 const Vector_u32* labels_v, 00191 const Matrix_d* confidence_matrix, 00192 const SVM_model* model 00193 ) 00194 { 00195 uint32_t s, num_samples; 00196 uint32_t l, num_labels; 00197 uint32_t label; 00198 int* labels; 00199 Vector_d* confidence; 00200 00201 num_samples = labels_v->num_elts; 00202 00203 create_zero_vector_d(confidence_out, num_samples); 00204 confidence = *confidence_out; 00205 00206 num_labels = (uint32_t)svm_get_nr_class(model->svm); 00207 assert(labels = malloc(num_labels*sizeof(uint32_t))); 00208 svm_get_labels(model->svm, labels); 00209 00210 for (s = 0; s < num_samples; s++) 00211 { 00212 label = labels_v->elts[ s ]; 00213 00214 for (l = 0; l < num_labels; l++) 00215 { 00216 if ((uint32_t)labels[ l ] == label) 00217 { 00218 confidence->elts[ s ] = confidence_matrix->elts[ s ][ l ]; 00219 } 00220 } 00221 } 00222 00223 free(labels); 00224 } 00225 00226 00240 Error* predict_labels_with_svm_model 00241 ( 00242 Vector_u32** labels_out, 00243 Vector_d** confidence_out, 00244 const Matrix_i32* markers, 00245 const SVM_model* model 00246 ) 00247 { 00248 uint32_t s, num_samples; 00249 uint32_t m, num_markers; 00250 00251 Vector_u32* labels; 00252 Matrix_d* confidence_matrix; 00253 00254 struct svm_node* markers_v; 00255 00256 num_samples = markers->num_rows; 00257 num_markers = markers->num_cols; 00258 00259 assert(markers_v = malloc((num_markers+1)*sizeof(struct svm_node))); 00260 00261 create_zero_vector_u32(labels_out, num_samples); 00262 labels = *labels_out; 00263 00264 confidence_matrix = NULL; 00265 create_zero_matrix_d(&confidence_matrix, num_samples, 00266 svm_get_nr_class(model->svm)); 00267 00268 for (s = 0; s < num_samples; s++) 00269 { 00270 for (m = 0; m < num_markers; m++) 00271 { 00272 markers_v[ m ].index = m+1; 00273 markers_v[ m ].value = markers->elts[ s ][ m ]; 00274 } 00275 markers_v[ m].index = -1; 00276 00277 predict_label_with_svm_model(&(labels->elts[ s ]), 00278 &(confidence_matrix->elts[ s ]), markers_v, model); 00279 } 00280 00281 get_predicted_labels_confidence(confidence_out, labels, confidence_matrix, 00282 model); 00283 00284 free(markers_v); 00285 free_matrix_d(confidence_matrix); 00286 00287 return NULL; 00288 } 00289 00290 00299 Error* read_svm_model(SVM_model** model_out, const char* fname) 00300 { 00301 int slen = 256; 00302 char str[slen]; 00303 00304 if (*model_out) 00305 { 00306 free_svm_model(*model_out); 00307 } 00308 00309 assert(*model_out = malloc(sizeof(SVM_model))); 00310 (*model_out)->prob = NULL; 00311 00312 if (!((*model_out)->svm = svm_load_model(fname))) 00313 { 00314 snprintf(str, slen, "%s: %s", fname, "Could not read model"); 00315 return JWSC_EARG(str); 00316 } 00317 00318 return NULL; 00319 } 00320 00321 00328 Error* write_svm_model(SVM_model* model, const char* fname) 00329 { 00330 int slen = 256; 00331 char str[slen]; 00332 00333 if (svm_save_model(fname, model->svm) < 0) 00334 { 00335 snprintf(str, slen, "%s: %s", fname, "Could not write model"); 00336 return JWSC_EARG(str); 00337 } 00338 00339 return NULL; 00340 } 00341 00342 00356 Error* write_svm_model_training_data 00357 ( 00358 const Vector_u32* labels, 00359 const Matrix_i32* markers, 00360 const char* fname 00361 ) 00362 { 00363 FILE* fp; 00364 uint32_t s, num_samples; 00365 uint32_t m, num_markers; 00366 int slen = 256; 00367 char str[slen]; 00368 00369 num_samples = markers->num_rows; 00370 num_markers = markers->num_cols; 00371 00372 if ((fp = fopen(fname, "w")) == NULL) 00373 { 00374 snprintf(str, slen, "%s: %s", fname, strerror(errno)); 00375 return JWSC_EIO(str); 00376 } 00377 00378 for (s = 0; s < num_samples; s++) 00379 { 00380 fprintf(fp, "%-4d", labels->elts[ s ]); 00381 for (m = 0; m < num_markers; m++) 00382 { 00383 fprintf(fp, " %2d:%-4d", m+1, markers->elts[ s ][ m ]); 00384 } 00385 fprintf(fp, "\n"); 00386 } 00387 00388 if (fclose(fp) != 0) 00389 { 00390 snprintf(str, slen, "%s: %s", fname, strerror(errno)); 00391 return JWSC_EIO(str); 00392 } 00393 00394 return NULL; 00395 } 00396 00397 00399 void free_svm_model(SVM_model* model) 00400 { 00401 uint32_t i; 00402 00403 if (!model) 00404 return; 00405 00406 svm_free_and_destroy_model(&(model->svm)); 00407 00408 if (model->prob) 00409 { 00410 for (i = 0; i < model->prob->l; i++ ) 00411 { 00412 free(model->prob->x[ i ]); 00413 } 00414 free(model->prob->x); 00415 free(model->prob->y); 00416 free(model->prob); 00417 } 00418 00419 free(model); 00420 } 00421 00422 00424 static Error* read_svm_xml_doc 00425 ( 00426 xmlDoc** xml_doc_out, 00427 const char* xml_fname, 00428 const char* dtd_fname 00429 ) 00430 { 00431 xmlParserCtxt* xml_parse_ctxt; 00432 xmlValidCtxt* xml_valid_ctxt; 00433 xmlDtd* xml_dtd; 00434 int slen = 256; 00435 char str[slen]; 00436 00437 assert(xml_parse_ctxt = xmlNewParserCtxt()); 00438 00439 if (!(*xml_doc_out = xmlCtxtReadFile(xml_parse_ctxt, xml_fname, NULL, 0))) 00440 { 00441 snprintf(str, slen, "%s: %s", xml_fname, "Could not parse file"); 00442 return JWSC_EARG(str); 00443 } 00444 00445 xmlFreeParserCtxt(xml_parse_ctxt); 00446 00447 if (dtd_fname) 00448 { 00449 assert(xml_valid_ctxt = xmlNewValidCtxt()); 00450 00451 if (!(xml_dtd = xmlParseDTD(NULL, (xmlChar*)dtd_fname))) 00452 { 00453 snprintf(str, slen, "%s: %s", dtd_fname, "Could not parse DTD"); 00454 return JWSC_EARG(str); 00455 } 00456 00457 if (!xmlValidateDtd(xml_valid_ctxt, *xml_doc_out, xml_dtd)) 00458 { 00459 snprintf(str, slen, "%s: %s", xml_fname, "XML file not valid"); 00460 return JWSC_EARG(str); 00461 } 00462 00463 xmlFreeValidCtxt(xml_valid_ctxt); 00464 xmlFreeDtd(xml_dtd); 00465 } 00466 00467 return NULL; 00468 } 00469 00470 static Error* create_svm_model_tree_from_xml_node 00471 ( 00472 SVM_model_tree** tree_out, 00473 const SVM_model_tree* parent, 00474 uint32_t parent_label, 00475 xmlNode* xml_node 00476 ); 00477 00479 static Error* create_svm_model_node_from_xml_node 00480 ( 00481 xmlNode* xml_node, 00482 SVM_model_node* svm_node, 00483 uint32_t i 00484 ) 00485 { 00486 uint32_t label; 00487 uint32_t altlabel; 00488 uint32_t num_altlabels; 00489 uint32_t g; 00490 uint32_t a; 00491 const char* fname; 00492 xmlNode* it; 00493 xmlNode* itt; 00494 Error* err; 00495 int slen = 256; 00496 char str[slen]; 00497 00498 g = 0; 00499 for (it = xml_node; it; it = it->next) 00500 { 00501 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "file")) 00502 { 00503 fname = (const char*) it->children->content; 00504 svm_node->model_fnames[ i ] = malloc((strlen(fname)+1) * 00505 sizeof(char)); 00506 strcpy(svm_node->model_fnames[ i ], fname); 00507 } 00508 else if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "cost")) 00509 { 00510 if (sscanf((char*)it->children->content, "%lf", 00511 &(svm_node->cost->elts[ i ])) != 1) 00512 { 00513 snprintf(str, slen, "%s: %s", fname, "Invalid cost"); 00514 return JWSC_EARG(str); 00515 } 00516 } 00517 else if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "gamma")) 00518 { 00519 if (sscanf((char*)it->children->content, "%lf", 00520 &(svm_node->gamma->elts[ i ])) != 1) 00521 { 00522 snprintf(str, slen, "%s: %s", fname, "Invalid gamma"); 00523 return JWSC_EARG(str); 00524 } 00525 } 00526 else if (it->type == XML_ELEMENT_NODE && 00527 (XMLStrEqual(it->name, "group-one") || 00528 XMLStrEqual(it->name, "group-all"))) 00529 { 00530 a = 0; 00531 num_altlabels = 0; 00532 for (itt = it->children; itt; itt = itt->next) 00533 { 00534 if (itt->type == XML_ELEMENT_NODE && 00535 XMLStrEqual(itt->name, "altlabel")) 00536 { 00537 num_altlabels++; 00538 } 00539 } 00540 if (num_altlabels) 00541 { 00542 create_vector_u32(&(svm_node->altlabels[ g ][ i ]), 00543 num_altlabels); 00544 } 00545 00546 for (itt = it->children; itt; itt = itt->next) 00547 { 00548 if (itt->type == XML_ELEMENT_NODE && 00549 XMLStrEqual(itt->name, "label")) 00550 { 00551 if ((err = lookup_haplo_group_index_from_label(&label, 00552 (const char*) itt->children->content))) 00553 { 00554 snprintf(str, slen, "%s: %s", fname, err->msg); 00555 return JWSC_EARG(str); 00556 } 00557 svm_node->labels[ g ]->elts[ i ] = label; 00558 } 00559 else if (itt->type == XML_ELEMENT_NODE && 00560 XMLStrEqual(itt->name, "altlabel")) 00561 { 00562 if ((err = lookup_haplo_group_index_from_label(&altlabel, 00563 (const char*) itt->children->content))) 00564 { 00565 snprintf(str, slen, "%s: %s", fname, err->msg); 00566 return JWSC_EARG(str); 00567 } 00568 svm_node->altlabels[ g ][ i ]->elts[ a++ ] = altlabel; 00569 } 00570 else if (itt->type == XML_ELEMENT_NODE && 00571 (XMLStrEqual(itt->name, "binary-model") || 00572 XMLStrEqual(itt->name, "one-vs-all-model") || 00573 XMLStrEqual(itt->name, "one-vs-one-model"))) 00574 { 00575 if ((err = create_svm_model_tree_from_xml_node( 00576 &(svm_node->subtrees[ g ][ i ]), svm_node, 00577 label, itt))) 00578 { 00579 snprintf(str, slen, "%s: %s", fname, err->msg); 00580 return JWSC_EARG(str); 00581 } 00582 break; 00583 } 00584 } 00585 g++; 00586 } 00587 } 00588 00589 return NULL; 00590 } 00591 00592 00599 static Error* create_svm_model_tree_from_xml_node 00600 ( 00601 SVM_model_tree** tree_out, 00602 const SVM_model_tree* parent, 00603 uint32_t parent_label, 00604 xmlNode* xml_node 00605 ) 00606 { 00607 SVM_model_tree* tree; 00608 uint32_t i; 00609 xmlNode* it; 00610 Error* err; 00611 00612 if (*tree_out) 00613 { 00614 free_svm_model_tree(*tree_out); 00615 } 00616 00617 assert(*tree_out = malloc(sizeof(SVM_model_tree))); 00618 tree = *tree_out; 00619 00620 tree->parent = parent; 00621 tree->parent_label = parent_label; 00622 tree->subtrees[0] = NULL; 00623 tree->subtrees[1] = NULL; 00624 tree->labels[0] = NULL; 00625 tree->labels[1] = NULL; 00626 tree->altlabels[0] = NULL; 00627 tree->altlabels[1] = NULL; 00628 tree->cost = NULL; 00629 tree->gamma = NULL; 00630 tree->num_models = 0; 00631 tree->models = NULL; 00632 tree->model_fnames = NULL; 00633 00634 assert(XMLStrEqual(xml_node->name, "binary-model") || 00635 XMLStrEqual(xml_node->name, "one-vs-all-model") || 00636 XMLStrEqual(xml_node->name, "one-vs-one-model")); 00637 00638 for (it = xml_node; it; it = it->next) 00639 { 00640 if (it->type == XML_ELEMENT_NODE && 00641 (XMLStrEqual(it->name, "binary-model") || 00642 XMLStrEqual(it->name, "one-vs-all-model") || 00643 XMLStrEqual(it->name, "one-vs-one-model"))) 00644 { 00645 tree->num_models++; 00646 } 00647 } 00648 00649 assert(tree->subtrees[0] = calloc(tree->num_models, sizeof(void*))); 00650 assert(tree->subtrees[1] = calloc(tree->num_models, sizeof(void*))); 00651 assert(tree->altlabels[0] = calloc(tree->num_models, sizeof(void*))); 00652 assert(tree->altlabels[1] = calloc(tree->num_models, sizeof(void*))); 00653 assert(tree->model_fnames = calloc(tree->num_models, sizeof(void*))); 00654 assert(tree->models = calloc(tree->num_models, sizeof(void*))); 00655 00656 create_vector_u32(&(tree->labels[0]), tree->num_models); 00657 create_vector_u32(&(tree->labels[1]), tree->num_models); 00658 create_vector_d(&(tree->cost), tree->num_models); 00659 create_vector_d(&(tree->gamma), tree->num_models); 00660 00661 i = 0; 00662 for (it = xml_node; it; it = it->next) 00663 { 00664 if (it->type == XML_ELEMENT_NODE && 00665 (XMLStrEqual(it->name, "binary-model") || 00666 XMLStrEqual(it->name, "one-vs-all-model") || 00667 XMLStrEqual(it->name, "one-vs-one-model"))) 00668 { 00669 if ((err = create_svm_model_node_from_xml_node(it->children, tree, 00670 i++))) 00671 { 00672 return err; 00673 } 00674 } 00675 } 00676 assert(i == tree->num_models); 00677 00678 return NULL; 00679 } 00680 00681 00686 static Error* create_model_training_data 00687 ( 00688 Vector_u32** train_labels_out, 00689 Matrix_i32** train_markers_out, 00690 const Vector_u32* data_labels, 00691 const Matrix_i32* data_markers, 00692 uint32_t* model_labels, 00693 const Vector_u32** model_altlabels 00694 ) 00695 { 00696 uint8_t b; 00697 uint32_t i, j, k; 00698 uint32_t n; 00699 uint32_t m[2] = {0}; 00700 const char* label_str; 00701 Vector_u32* train_labels = NULL; 00702 Matrix_i32* train_markers = NULL; 00703 int slen = 256; 00704 char str[slen]; 00705 00706 copy_vector_u32(&train_labels, data_labels); 00707 copy_matrix_i32(&train_markers, data_markers); 00708 00709 n = 0; 00710 for (i = 0; i < data_labels->num_elts; i++) 00711 { 00712 for (j = 0; j < 2; j++) 00713 { 00714 if (is_ancestor(data_labels->elts[ i ], model_labels[ j ])) 00715 { 00716 train_labels->elts[ n ] = model_labels[ j ]; 00717 copy_matrix_block_into_matrix_i32(train_markers, n, 0, 00718 data_markers, i, 0, 1, data_markers->num_cols); 00719 n++; 00720 m[j]++; 00721 break; 00722 } 00723 else if (model_altlabels[ j ]) 00724 { 00725 for (k = 0; k < model_altlabels[ j ]->num_elts; k++) 00726 { 00727 if ((b = is_ancestor(data_labels->elts[ i ], 00728 model_altlabels[ j ]->elts[ k ]))) 00729 { 00730 train_labels->elts[ n ] = model_labels[ j ]; 00731 copy_matrix_block_into_matrix_i32(train_markers, n, 0, 00732 data_markers, i, 0, 1, data_markers->num_cols); 00733 n++; 00734 m[j]++; 00735 break; 00736 } 00737 } 00738 if (b) 00739 { 00740 break; 00741 } 00742 } 00743 } 00744 } 00745 00746 assert(n == (m[0] + m[1])); 00747 00748 if (!(m[j=0]) || !(m[j=1])) 00749 { 00750 lookup_haplo_group_label_from_index(&label_str, model_labels[j]); 00751 snprintf(str, slen, "No data for model label %s", label_str); 00752 return JWSC_EARG(str); 00753 } 00754 00755 copy_vector_section_u32(train_labels_out, train_labels, 0, n); 00756 copy_matrix_block_i32(train_markers_out, train_markers, 0, 0, n, 00757 train_markers->num_cols); 00758 00759 free_vector_u32(train_labels); 00760 free_matrix_i32(train_markers); 00761 00762 return NULL; 00763 } 00764 00765 00767 static Error* train_svm_model_node 00768 ( 00769 SVM_model_node* node, 00770 const Vector_u32* labels, 00771 const Matrix_i32* markers 00772 ) 00773 { 00774 uint32_t i, j; 00775 uint32_t label[2]; 00776 const Vector_u32* altlabel[2]; 00777 Vector_u32* node_labels = NULL; 00778 Matrix_i32* node_markers = NULL; 00779 Error* err; 00780 int slen = 256; 00781 char str[slen]; 00782 00783 for (i = 0; i < node->num_models; i++) 00784 { 00785 label[0] = node->labels[0]->elts[ i ]; 00786 label[1] = node->labels[1]->elts[ i ]; 00787 00788 altlabel[0] = node->altlabels[0][ i ]; 00789 altlabel[1] = node->altlabels[1][ i ]; 00790 00791 if ((err = create_model_training_data(&node_labels, &node_markers, 00792 labels, markers, label, altlabel))) 00793 { 00794 snprintf(str, slen, "%s: %s", node->model_fnames[i], err->msg); 00795 return JWSC_EARG(str); 00796 } 00797 00798 train_svm_model(&(node->models[ i ]), node_labels, node_markers, 00799 node->cost->elts[ i ], node->gamma->elts[ i ]); 00800 00801 for (j = 0; j < 2; j++) 00802 { 00803 if (node->subtrees[ j ][ i ]) 00804 { 00805 if ((err = train_svm_model_node(node->subtrees[ j ][ i ], 00806 labels, markers))) 00807 { 00808 return err; 00809 } 00810 } 00811 } 00812 } 00813 00814 free_vector_u32(node_labels); 00815 free_matrix_i32(node_markers); 00816 00817 return NULL; 00818 } 00819 00820 00828 Error* train_svm_model_tree 00829 ( 00830 SVM_model_tree** tree_out, 00831 const Vector_u32* labels, 00832 const Matrix_i32* markers, 00833 const char* tree_xml_fname, 00834 const char* tree_dtd_fname 00835 ) 00836 { 00837 xmlDoc* xml_doc; 00838 xmlNode* xml_node = NULL; 00839 xmlNode* it; 00840 Error* err; 00841 int slen = 256; 00842 char str[slen]; 00843 00844 assert(tree_xml_fname); 00845 00846 if ((err = read_svm_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname))) 00847 { 00848 return err; 00849 } 00850 00851 for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next) 00852 { 00853 if (it->type == XML_ELEMENT_NODE && 00854 (XMLStrEqual(it->name, "binary-model") || 00855 XMLStrEqual(it->name, "one-vs-all-model") || 00856 XMLStrEqual(it->name, "one-vs-one-model"))) 00857 { 00858 xml_node = it; 00859 break; 00860 } 00861 } 00862 00863 assert(xml_node); 00864 00865 if ((err = create_svm_model_tree_from_xml_node(tree_out, NULL, 0, 00866 xml_node))) 00867 { 00868 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 00869 return JWSC_EARG(str); 00870 } 00871 00872 if ((err = train_svm_model_node(*tree_out, labels, markers))) 00873 { 00874 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 00875 return JWSC_EARG(str); 00876 } 00877 00878 return NULL; 00879 } 00880 00881 00883 static double get_predicted_label_confidence 00884 ( 00885 uint32_t label, 00886 double* confs, 00887 const SVM_model* model 00888 ) 00889 { 00890 uint32_t l, num_labels; 00891 int* labels; 00892 double conf; 00893 00894 num_labels = (uint32_t)svm_get_nr_class(model->svm); 00895 assert(labels = malloc(num_labels*sizeof(int))); 00896 svm_get_labels(model->svm, labels); 00897 00898 for (l = 0; l < num_labels; l++) 00899 { 00900 if (labels[ l ] == label) 00901 { 00902 conf = confs[ l ]; 00903 } 00904 } 00905 00906 free(labels); 00907 00908 return conf; 00909 } 00910 00911 00913 static uint32_t get_best_one_against_all_model 00914 ( 00915 const SVM_model_tree* tree, 00916 Vector_u32* labels, 00917 Vector_d* confs 00918 ) 00919 { 00920 uint32_t label, other_label; 00921 uint32_t num_models; 00922 uint32_t m; 00923 uint32_t num_labels; 00924 int32_t best_model = -1; 00925 double max_conf = 0; 00926 double min_conf = 1; 00927 double label_conf; 00928 double normalizing_c = 0; 00929 00930 Vector_d* label_conf_sum = NULL; 00931 00932 num_models = tree->num_models; 00933 num_labels = get_num_haplo_groups(); 00934 00935 assert(num_models > 1); 00936 00937 create_init_vector_d(&label_conf_sum, num_labels, 0); 00938 00939 /* Use the sum the confidence values for each predicted label to decide 00940 * which is best. If there is a tie, the first occuring label is choosen. */ 00941 for (m = 0; m < num_models; m++) 00942 { 00943 label = labels->elts[ m ]; 00944 label_conf = confs->elts[ m ]; 00945 00946 if (tree->labels[0]->elts[ m ] == label) 00947 { 00948 label_conf_sum->elts[ label ] += label_conf; 00949 other_label = tree->labels[1]->elts[ m ]; 00950 } 00951 else 00952 { 00953 label_conf_sum->elts[ tree->labels[0]->elts[ m ] ] += 1-label_conf; 00954 other_label = tree->labels[0]->elts[ m ]; 00955 } 00956 } 00957 00958 for (m = 0; m < num_models; m++) 00959 { 00960 label = labels->elts[ m ]; 00961 label_conf = label_conf_sum->elts[ label ]; 00962 00963 if ((label == tree->labels[0]->elts[ m ]) && (max_conf < label_conf)) 00964 { 00965 max_conf = label_conf; 00966 best_model = m; 00967 } 00968 } 00969 00970 /* It's possible that nothing was choosen because the 'other_label' was 00971 * selected in all groups. In this case, choose the predicted label whose 00972 * model had the least confidence in its 'other_label' prediction. */ 00973 if (best_model < 0) 00974 { 00975 for (m = 0; m < num_models; m++) 00976 { 00977 label = labels->elts[ m ]; 00978 assert(label == tree->labels[1]->elts[ m ]); 00979 label_conf = confs->elts[ m ]; 00980 00981 if (min_conf > label_conf) 00982 { 00983 min_conf = label_conf; 00984 best_model = m; 00985 } 00986 } 00987 00988 /* Hack the predicted_labels and confidence data structures to reflect 00989 * that 'other_label' was not choosen. */ 00990 labels->elts[ best_model ] = tree->labels[0]->elts[ best_model ]; 00991 00992 label_conf = confs->elts[ best_model ]; 00993 confs->elts[ best_model ] = 1.0 - label_conf; 00994 } 00995 00996 /* Need to normalize the confidence value for the best label prediction so 00997 * that it is a probability. At this level in the tree, one of the labels 00998 * will be chosen, so the confidence value for these labels needs to sum 00999 * to one. */ 01000 for (label = 0; label < num_labels; label++) 01001 { 01002 normalizing_c += label_conf_sum->elts[ label ]; 01003 } 01004 assert(normalizing_c > 1.0e-16); 01005 confs->elts[ best_model ] /= normalizing_c; 01006 01007 free_vector_d(label_conf_sum); 01008 01009 return best_model; 01010 } 01011 01012 01014 static uint32_t get_best_one_against_one_model 01015 ( 01016 const SVM_model_tree* tree, 01017 Vector_u32* labels, 01018 Vector_d* confs 01019 ) 01020 { 01021 uint32_t m, num_models; 01022 uint32_t l, num_labels; 01023 uint32_t best_model; 01024 uint32_t max; 01025 uint32_t the_winner; 01026 uint32_t num_winners; 01027 double max_conf; 01028 01029 Vector_u32* winners = NULL; 01030 01031 num_models = tree->num_models; 01032 num_labels = get_num_haplo_groups(); 01033 01034 assert(num_models > 1); 01035 01036 create_init_vector_u32(&winners, num_labels, 0); 01037 01038 for (m = 0; m < num_models; m++) 01039 { 01040 winners->elts[ labels->elts[ m ] ]++; 01041 } 01042 01043 max = 0; 01044 the_winner = 0; 01045 num_winners = 0; 01046 01047 for (l = 0; l < num_labels; l++) 01048 { 01049 if (winners->elts[ l ] > max) 01050 { 01051 the_winner = l; 01052 max = winners->elts[ l ]; 01053 } 01054 } 01055 for (l = 0; l < num_labels; l++) 01056 { 01057 if (max == winners->elts[ l ]) 01058 { 01059 num_winners++; 01060 } 01061 } 01062 01063 assert(max > 0); 01064 01065 max_conf = 0; 01066 01067 for (m = 0; m < num_models; m++) 01068 { 01069 if (labels->elts[ m ] == the_winner && confs->elts[ m ] > max_conf) 01070 { 01071 best_model = m; 01072 } 01073 } 01074 01075 free_vector_u32(winners); 01076 01077 return best_model; 01078 } 01079 01080 01082 static Error* recursively_predict_label_in_model_tree 01083 ( 01084 uint32_t* label_out, 01085 double* confidence_out, 01086 const SVM_model_tree* tree, 01087 const struct svm_node* markers 01088 ) 01089 { 01090 uint32_t i; 01091 uint32_t m, num_models; 01092 uint32_t subtree_label; 01093 Error* err; 01094 01095 Vector_u32* labels = NULL; 01096 Vector_d* confs = NULL; 01097 Vector_d* v = NULL; 01098 01099 num_models = tree->num_models; 01100 01101 create_vector_u32(&labels, num_models); 01102 create_vector_d(&confs, num_models); 01103 01104 for (m = 0; m < num_models; m++) 01105 { 01106 create_zero_vector_d(&v, svm_get_nr_class(tree->models[ m ]->svm)); 01107 01108 if ((err = predict_label_with_svm_model(&(labels->elts[ m ]), 01109 &(v->elts), markers, tree->models[ m ]))) 01110 { 01111 return err; 01112 } 01113 01114 confs->elts[ m ] = get_predicted_label_confidence(labels->elts[ m ], 01115 v->elts, tree->models[ m ]); 01116 } 01117 free_vector_d(v); 01118 01119 // If there are multiple models in the node, get the prediction from 01120 // the best one. 01121 m = (num_models > 1) ? get_best_one_against_all_model(tree, labels, confs) 01122 : 0; 01123 01124 *label_out = labels->elts[ m ]; 01125 *confidence_out *= confs->elts[ m ]; 01126 01127 free_vector_u32(labels); 01128 free_vector_d(confs); 01129 01130 // If the best model has a subtree, recursively predict in the subtree. 01131 for (m = 0; m < num_models; m++) 01132 { 01133 for (i = 0; i < 2; i++) 01134 { 01135 if (!(tree->subtrees[ i ][ m ])) 01136 continue; 01137 01138 subtree_label = tree->subtrees[ i ][ m ]->parent_label; 01139 01140 if (subtree_label == *label_out) 01141 { 01142 if ((err = recursively_predict_label_in_model_tree(label_out, 01143 confidence_out, tree->subtrees[ i ][ m ], 01144 markers))) 01145 { 01146 return err; 01147 } 01148 } 01149 } 01150 } 01151 01152 return NULL; 01153 } 01154 01155 01169 Error* predict_labels_with_svm_model_tree 01170 ( 01171 Vector_u32** labels_out, 01172 Vector_d** confidence_out, 01173 const Matrix_i32* markers, 01174 const SVM_model_tree* tree 01175 ) 01176 { 01177 uint32_t s, num_samples; 01178 uint32_t m, num_markers; 01179 01180 Vector_u32* labels; 01181 Vector_d* confidence; 01182 struct svm_node* markers_v; 01183 Error* e; 01184 01185 num_samples = markers->num_rows; 01186 num_markers = markers->num_cols; 01187 01188 create_vector_u32(labels_out, num_samples); 01189 labels = *labels_out; 01190 01191 create_init_vector_d(confidence_out, num_samples, 1.0); 01192 confidence = *confidence_out; 01193 01194 markers_v = NULL; 01195 assert(markers_v = malloc((num_markers+1)*sizeof(struct svm_node))); 01196 01197 for (s = 0; s < num_samples; s++) 01198 { 01199 for (m = 0; m < num_markers; m ++) 01200 { 01201 markers_v[ m ].index = m +1; 01202 markers_v[ m ].value = markers->elts[ s ][ m ]; 01203 } 01204 markers_v[ m ].index = -1; 01205 01206 if ((e = recursively_predict_label_in_model_tree( 01207 &(labels->elts[ s ]), 01208 &(confidence->elts[ s ]), tree, markers_v)) 01209 != NULL) 01210 { 01211 return e; 01212 } 01213 } 01214 01215 free(markers_v); 01216 01217 return NULL; 01218 } 01219 01220 01222 static Error* read_svm_model_node 01223 ( 01224 SVM_model_node* node, 01225 const char* model_dirname 01226 ) 01227 { 01228 uint32_t i, j; 01229 char buf[1024] = {0}; 01230 Error* err; 01231 01232 for (i = 0; i < node->num_models; i++) 01233 { 01234 snprintf(buf, 1024, "%s/%s", model_dirname, node->model_fnames[ i ]); 01235 if ((err = read_svm_model(&(node->models[ i ]), buf))) 01236 { 01237 return err; 01238 } 01239 01240 for (j = 0; j < 2; j++) 01241 { 01242 if (node->subtrees[ j ][ i ]) 01243 { 01244 if ((err = read_svm_model_node(node->subtrees[ j ][ i ], 01245 model_dirname))) 01246 { 01247 return err; 01248 } 01249 } 01250 } 01251 } 01252 01253 return NULL; 01254 } 01255 01256 01263 Error* read_svm_model_tree 01264 ( 01265 SVM_model_tree** tree_out, 01266 const char* tree_xml_fname, 01267 const char* tree_dtd_fname, 01268 const char* model_dirname 01269 ) 01270 { 01271 xmlDoc* xml_doc; 01272 xmlNode* xml_node = NULL; 01273 xmlNode* it; 01274 Error* err; 01275 int slen = 256; 01276 char str[slen]; 01277 01278 assert(tree_xml_fname && model_dirname); 01279 01280 if ((err = read_svm_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname))) 01281 { 01282 return err; 01283 } 01284 01285 for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next) 01286 { 01287 if (it->type == XML_ELEMENT_NODE && 01288 (XMLStrEqual(it->name, "binary-model") || 01289 XMLStrEqual(it->name, "one-vs-all-model") || 01290 XMLStrEqual(it->name, "one-vs-one-model"))) 01291 { 01292 xml_node = it; 01293 break; 01294 } 01295 } 01296 01297 assert(xml_node); 01298 01299 if ((err = create_svm_model_tree_from_xml_node(tree_out, NULL, 0, 01300 xml_node))) 01301 { 01302 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 01303 return JWSC_EARG(str); 01304 } 01305 01306 if ((err = read_svm_model_node(*tree_out, model_dirname))) 01307 { 01308 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 01309 return JWSC_EARG(str); 01310 } 01311 01312 return NULL; 01313 } 01314 01315 01320 Error* write_svm_model_tree 01321 ( 01322 const SVM_model_tree* tree, 01323 const char* model_dirname 01324 ) 01325 { 01326 uint32_t i, j; 01327 char buf[1024] = {0}; 01328 Error* err; 01329 int slen = 256; 01330 char str[slen]; 01331 01332 for (i = 0; i < tree->num_models; i++) 01333 { 01334 snprintf(buf, 1024, "%s/%s", model_dirname, tree->model_fnames[ i ]); 01335 if ((err = write_svm_model(tree->models[ i ], buf))) 01336 { 01337 return err; 01338 } 01339 01340 for (j = 0; j < 2; j++) 01341 { 01342 if (tree->subtrees[ j ][ i ]) 01343 { 01344 if ((err = write_svm_model_tree(tree->subtrees[ j ][ i ], 01345 model_dirname))) 01346 { 01347 snprintf(str, slen, "%s: %s", tree->model_fnames[ i ], 01348 err->msg); 01349 return JWSC_EARG(str); 01350 } 01351 } 01352 } 01353 } 01354 01355 return NULL; 01356 } 01357 01358 01363 static Error* write_svm_model_node_training_data 01364 ( 01365 SVM_model_node* node, 01366 const Vector_u32* labels, 01367 const Matrix_i32* markers, 01368 const char* data_dirname 01369 ) 01370 { 01371 uint32_t i, j; 01372 uint32_t label[2]; 01373 char buf[1024] = {0}; 01374 const Vector_u32* altlabel[2]; 01375 Vector_u32* node_labels = NULL; 01376 Matrix_i32* node_markers = NULL; 01377 Error* err; 01378 int slen = 256; 01379 char str[slen]; 01380 01381 for (i = 0; i < node->num_models; i++) 01382 { 01383 label[0] = node->labels[0]->elts[ i ]; 01384 label[1] = node->labels[1]->elts[ i ]; 01385 01386 altlabel[0] = node->altlabels[0][ i ]; 01387 altlabel[1] = node->altlabels[1][ i ]; 01388 01389 if ((err = create_model_training_data(&node_labels, &node_markers, 01390 labels, markers, label, altlabel))) 01391 { 01392 snprintf(str, slen, "%s: %s", node->model_fnames[i], err->msg); 01393 return JWSC_EARG(str); 01394 } 01395 01396 snprintf(buf, 1024, "%s/%s", data_dirname, node->model_fnames[ i ]); 01397 01398 if ((err = write_svm_model_training_data(node_labels, node_markers, 01399 buf))) 01400 { 01401 return err; 01402 } 01403 01404 for (j = 0; j < 2; j++) 01405 { 01406 if (node->subtrees[ j ][ i ]) 01407 { 01408 if ((err = write_svm_model_node_training_data( 01409 node->subtrees[ j ][ i ], labels, markers, 01410 data_dirname))) 01411 { 01412 return err; 01413 } 01414 } 01415 } 01416 } 01417 01418 free_vector_u32(node_labels); 01419 free_matrix_i32(node_markers); 01420 01421 return NULL; 01422 } 01423 01424 01435 Error* write_svm_model_tree_training_data 01436 ( 01437 const Vector_u32* labels, 01438 const Matrix_i32* markers, 01439 const char* tree_xml_fname, 01440 const char* tree_dtd_fname, 01441 const char* data_dirname 01442 ) 01443 { 01444 SVM_model_tree* tree = NULL; 01445 xmlDoc* xml_doc; 01446 xmlNode* xml_node = NULL; 01447 xmlNode* it; 01448 Error* err; 01449 int slen = 256; 01450 char str[slen]; 01451 01452 assert(tree_xml_fname); 01453 01454 if ((err = read_svm_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname))) 01455 { 01456 return err; 01457 } 01458 01459 for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next) 01460 { 01461 if (it->type == XML_ELEMENT_NODE && 01462 (XMLStrEqual(it->name, "binary-model") || 01463 XMLStrEqual(it->name, "one-vs-one-model") || 01464 XMLStrEqual(it->name, "one-vs-all-model"))) 01465 { 01466 xml_node = it; 01467 break; 01468 } 01469 } 01470 01471 assert(xml_node); 01472 01473 if ((err = create_svm_model_tree_from_xml_node(&tree, NULL, 0, xml_node))) 01474 { 01475 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 01476 return JWSC_EARG(str); 01477 } 01478 01479 if ((err = write_svm_model_node_training_data(tree, labels, markers, 01480 data_dirname))) 01481 { 01482 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 01483 return JWSC_EARG(str); 01484 } 01485 01486 free_svm_model_tree(tree); 01487 01488 return NULL; 01489 } 01490 01491 01493 void free_svm_model_tree(SVM_model_tree* tree) 01494 { 01495 uint32_t i; 01496 01497 if (!tree) 01498 return; 01499 01500 for (i = 0; i < tree->num_models; i++) 01501 { 01502 free_svm_model_tree(tree->subtrees[0][ i ]); 01503 free_svm_model_tree(tree->subtrees[1][ i ]); 01504 free_vector_u32(tree->altlabels[0][ i ]); 01505 free_vector_u32(tree->altlabels[1][ i ]); 01506 free_svm_model(tree->models[ i ]); 01507 free(tree->model_fnames[ i ]); 01508 } 01509 01510 free(tree->subtrees[0]); 01511 free(tree->subtrees[1]); 01512 free_vector_u32(tree->labels[0]); 01513 free_vector_u32(tree->labels[1]); 01514 free(tree->altlabels[0]); 01515 free(tree->altlabels[1]); 01516 free_vector_d(tree->cost); 01517 free_vector_d(tree->gamma); 01518 free(tree->models); 01519 free(tree->model_fnames); 01520 free(tree); 01521 }