/* Haplo Prediction: predict haplogroups. */
00001 /* 00002 * This work is licensed under a Creative Commons 00003 * Attribution-Noncommercial-Share Alike 3.0 United States License. 00004 * 00005 * http://creativecommons.org/licenses/by-nc-sa/3.0/us/ 00006 * 00007 * You are free: 00008 * 00009 * to Share - to copy, distribute, display, and perform the work 00010 * to Remix - to make derivative works 00011 * 00012 * Under the following conditions: 00013 * 00014 * Attribution. You must attribute the work in the manner specified by the 00015 * author or licensor (but not in any way that suggests that they endorse you 00016 * or your use of the work). 00017 * 00018 * Noncommercial. You may not use this work for commercial purposes. 00019 * 00020 * Share Alike. If you alter, transform, or build upon this work, you may 00021 * distribute the resulting work only under the same or similar license to 00022 * this one. 00023 * 00024 * For any reuse or distribution, you must make clear to others the license 00025 * terms of this work. The best way to do this is by including this header. 00026 * 00027 * Any of the above conditions can be waived if you get permission from the 00028 * copyright holder. 00029 * 00030 * Apart from the remix rights granted under this license, nothing in this 00031 * license impairs or restricts the author's moral rights. 
00032 */ 00033 00034 00046 #include <config.h> 00047 00048 #include <stdlib.h> 00049 #include <stdio.h> 00050 #include <inttypes.h> 00051 #include <assert.h> 00052 #include <string.h> 00053 #include <errno.h> 00054 #include <math.h> 00055 00056 #include <libxml/tree.h> 00057 #include <libxml/parser.h> 00058 #include <libxml/valid.h> 00059 00060 #ifdef HAPLO_HAVE_DMALLOC 00061 #include <dmalloc.h> 00062 #endif 00063 00064 #include <jwsc/base/error.h> 00065 #include <jwsc/base/limits.h> 00066 #include <jwsc/base/file_io.h> 00067 #include <jwsc/vector/vector.h> 00068 #include <jwsc/vector/vector_io.h> 00069 #include <jwsc/vector/vector_math.h> 00070 #include <jwsc/matrix/matrix.h> 00071 #include <jwsc/matrix/matrix_io.h> 00072 #include <jwsc/prob/pdf.h> 00073 00074 #include "xml.h" 00075 #include "haplo_groups.h" 00076 #include "nb_gauss.h" 00077 00078 00088 void train_nb_gauss_model 00089 ( 00090 NB_gauss_model** model_out, 00091 const Vector_u32* labels, 00092 const Matrix_i32* markers, 00093 const Vector_d* label_priors 00094 ) 00095 { 00096 uint32_t num_labels; 00097 uint32_t label; 00098 uint32_t num_samples; 00099 uint32_t sample; 00100 uint32_t num_markers; 00101 uint32_t marker; 00102 double marker_val; 00103 00104 Matrix_d* mu; 00105 Matrix_d* sigma; 00106 Vector_d* priors; 00107 Matrix_u32* marker_counts; 00108 00109 assert(markers->num_rows == labels->num_elts); 00110 00111 num_labels = get_num_haplo_groups(); 00112 00113 if (*model_out != NULL) 00114 { 00115 mu = (*model_out)->mu; 00116 sigma = (*model_out)->sigma; 00117 priors = (*model_out)->priors; 00118 } 00119 else 00120 { 00121 mu = NULL; 00122 sigma = NULL; 00123 priors = NULL; 00124 } 00125 00126 num_markers = markers->num_cols; 00127 num_samples = markers->num_rows; 00128 00129 create_zero_matrix_d(&mu, num_labels, num_markers); 00130 create_init_matrix_d(&sigma, num_labels, num_markers, 0.1); 00131 00132 marker_counts = NULL; 00133 create_zero_matrix_u32(&marker_counts, num_labels, num_markers); 
00134 00135 for (sample = 0; sample < num_samples; sample++) 00136 { 00137 label = labels->elts[ sample ]; 00138 00139 for (marker = 0; marker < num_markers; marker++) 00140 { 00141 marker_val = markers->elts[ sample ][ marker ]; 00142 00143 if (marker_val > 0) 00144 { 00145 marker_counts->elts[ label ][ marker ]++; 00146 mu->elts[ label ][ marker ] += marker_val; 00147 sigma->elts[ label ][ marker ] += marker_val*marker_val; 00148 } 00149 } 00150 } 00151 for (label = 0; label < num_labels; label++) 00152 { 00153 for (marker = 0; marker < num_markers; marker++) 00154 { 00155 if (marker_counts->elts[ label ][ marker ] > 0) 00156 { 00157 mu->elts[ label ][ marker ] /= 00158 marker_counts->elts[ label ][ marker ]; 00159 00160 sigma->elts[ label ][ marker ] /= 00161 marker_counts->elts[ label ][ marker ]; 00162 00163 sigma->elts[ label ][ marker ] -= 00164 mu->elts[ label ][ marker ] * mu->elts[ label ][ marker ]; 00165 00166 sigma->elts[ label ][ marker ] = 00167 sqrt(sigma->elts[ label ][ marker ]); 00168 } 00169 } 00170 } 00171 00172 if (label_priors == NULL) 00173 { 00174 create_zero_vector_d(&priors, num_labels); 00175 00176 for (sample = 0; sample < num_samples; sample++) 00177 { 00178 label = labels->elts[ sample ]; 00179 priors->elts[ label ]++; 00180 } 00181 normalize_vector_sum_d(&priors, priors); 00182 } 00183 else 00184 { 00185 normalize_vector_sum_d(&priors, label_priors); 00186 } 00187 00188 if (*model_out == NULL) 00189 { 00190 assert(*model_out = malloc(sizeof(NB_gauss_model))); 00191 } 00192 (*model_out)->num_markers = num_markers; 00193 (*model_out)->mu = mu; 00194 (*model_out)->sigma = sigma; 00195 (*model_out)->priors = priors; 00196 00197 free_matrix_u32(marker_counts); 00198 } 00199 00200 00212 Error* predict_label_with_nb_gauss_model 00213 ( 00214 uint32_t* label_out, 00215 double* confidence_out, 00216 const Vector_i32* markers, 00217 const NB_gauss_model* model, 00218 uint32_t order 00219 ) 00220 { 00221 uint32_t num_markers; 00222 uint32_t 
marker; 00223 uint32_t num_labels; 00224 uint32_t label; 00225 uint32_t best_label; 00226 uint32_t i; 00227 double marker_val; 00228 double mu, sigma; 00229 double log_map; 00230 double log_ll; 00231 double log_prior; 00232 double log_posterior; 00233 double posterior_sum; 00234 00235 Vector_u32* best_labels = NULL; 00236 Vector_d* best_labels_conf = NULL; 00237 00238 num_labels = get_num_haplo_groups(); 00239 num_markers = markers->num_elts; 00240 00241 log_map = log(JWSC_MIN_LOG_ARG); 00242 posterior_sum = 0; 00243 00244 create_init_vector_u32(&best_labels, num_labels, 0); 00245 create_init_vector_d(&best_labels_conf, num_labels, 0); 00246 00247 for (label = 0; label < num_labels; label++) 00248 { 00249 log_ll = 0; 00250 00251 for (marker = 0; marker < num_markers; marker++) 00252 { 00253 marker_val = markers->elts[ marker ]; 00254 00255 /* Ignore non-positive marker values. */ 00256 if (marker_val > 0) 00257 { 00258 mu = model->mu->elts[ label ][ marker ]; 00259 sigma = model->sigma->elts[ label ][ marker ]; 00260 00261 log_ll += log_gaussian_pdf_d(mu, sigma, marker_val); 00262 } 00263 } 00264 00265 if (model->priors->elts[ label ] < JWSC_MIN_LOG_ARG) 00266 { 00267 log_prior = log(JWSC_MIN_LOG_ARG); 00268 } 00269 else 00270 { 00271 log_prior = log(model->priors->elts[ label ]); 00272 } 00273 00274 log_posterior = log_ll + log_prior; 00275 posterior_sum += exp(log_posterior); 00276 00277 if (log_posterior > log_map) 00278 { 00279 log_map = log_posterior; 00280 best_label = label; 00281 for (i = num_labels-1; i > 0; i--) 00282 { 00283 best_labels->elts[ i ] = best_labels->elts[ i - 1]; 00284 best_labels_conf->elts[ i ] = best_labels_conf->elts[ i - 1]; 00285 } 00286 best_labels->elts[ 0 ] = best_label; 00287 best_labels_conf->elts[ 0 ] = log_map; 00288 } 00289 } 00290 00291 assert(order < best_labels->num_elts); 00292 00293 *label_out = best_labels->elts[ order ]; 00294 00295 if (posterior_sum > 0) 00296 { 00297 assert(finite(*confidence_out = 
exp(best_labels_conf->elts[order]) / 00298 (double)posterior_sum)); 00299 } 00300 else 00301 { 00302 *confidence_out = 0; 00303 } 00304 00305 free_vector_u32(best_labels); 00306 free_vector_d(best_labels_conf); 00307 00308 return NULL; 00309 } 00310 00311 00325 Error* predict_labels_with_nb_gauss_model 00326 ( 00327 Vector_u32** labels_out, 00328 Vector_d** confidence_out, 00329 const Matrix_i32* markers, 00330 const NB_gauss_model* model 00331 ) 00332 { 00333 uint32_t num_samples; 00334 uint32_t sample; 00335 uint32_t num_markers; 00336 uint32_t marker; 00337 00338 Vector_u32* labels; 00339 Vector_d* confidence; 00340 Vector_i32* markers_v; 00341 Error* e; 00342 00343 num_samples = markers->num_rows; 00344 num_markers = markers->num_cols; 00345 00346 create_vector_u32(labels_out, num_samples); 00347 labels = *labels_out; 00348 00349 create_vector_d(confidence_out, num_samples); 00350 confidence = *confidence_out; 00351 00352 markers_v = NULL; 00353 create_vector_i32(&markers_v, num_markers); 00354 00355 for (sample = 0; sample < num_samples; sample++) 00356 { 00357 for (marker = 0; marker < num_markers; marker++) 00358 { 00359 markers_v->elts[ marker ] = markers->elts[ sample ][marker]; 00360 } 00361 00362 if ((e = predict_label_with_nb_gauss_model(&(labels->elts[ sample ]), 00363 &(confidence->elts[ sample ]), markers_v, model, 0)) != NULL) 00364 { 00365 return e; 00366 } 00367 } 00368 00369 free_vector_i32(markers_v); 00370 00371 return NULL; 00372 } 00373 00374 00383 Error* read_nb_gauss_model(NB_gauss_model** model_out, const char* fname) 00384 { 00385 FILE* fp; 00386 NB_gauss_model* model; 00387 Error* e; 00388 int slen = 256; 00389 char str[slen]; 00390 00391 if ((fp = fopen(fname, "r")) == NULL) 00392 { 00393 snprintf(str, slen, "%s: %s", fname, strerror(errno)); 00394 return JWSC_EIO(str); 00395 } 00396 00397 if (*model_out == NULL) 00398 { 00399 assert((*model_out = malloc(sizeof(NB_gauss_model))) != NULL); 00400 model = *model_out; 00401 model->mu = 
NULL; 00402 model->sigma = NULL; 00403 model->priors = NULL; 00404 } 00405 00406 if (fscanf(fp, "%u\n", &(model->num_markers)) != 1) 00407 { 00408 snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file"); 00409 return JWSC_EIO(str); 00410 } 00411 00412 if ((e = read_matrix_with_header_fp_d(&(model->mu), fp)) != NULL) 00413 { 00414 snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file"); 00415 return JWSC_EIO(str); 00416 } 00417 00418 if ((e = read_matrix_with_header_fp_d(&(model->sigma), fp)) != NULL) 00419 { 00420 snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file"); 00421 return JWSC_EIO(str); 00422 } 00423 00424 if ((e = read_vector_with_header_fp_d(&(model->priors), fp)) != NULL) 00425 { 00426 snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file"); 00427 return JWSC_EIO(str); 00428 } 00429 00430 if (fclose(fp) != 0) 00431 { 00432 snprintf(str, slen, "%s: %s", fname, strerror(errno)); 00433 return JWSC_EIO(str); 00434 } 00435 00436 return NULL; 00437 } 00438 00439 00446 Error* write_nb_gauss_model(NB_gauss_model* model, const char* fname) 00447 { 00448 FILE* fp; 00449 int slen = 256; 00450 char str[slen]; 00451 00452 if ((fp = fopen(fname, "w")) == NULL) 00453 { 00454 snprintf(str, slen, "%s: %s", fname, strerror(errno)); 00455 return JWSC_EIO(str); 00456 } 00457 00458 fprintf(fp, "%u\n", model->num_markers); 00459 00460 write_matrix_with_header_fp_d(model->mu, fp); 00461 write_matrix_with_header_fp_d(model->sigma, fp); 00462 write_vector_with_header_fp_d(model->priors, fp); 00463 00464 if (fclose(fp) != 0) 00465 { 00466 snprintf(str, slen, "%s: %s", fname, strerror(errno)); 00467 return JWSC_EIO(str); 00468 } 00469 00470 return NULL; 00471 } 00472 00473 00475 void free_nb_gauss_model(NB_gauss_model* model) 00476 { 00477 if (!model) 00478 return; 00479 00480 free_matrix_d(model->mu); 00481 free_matrix_d(model->sigma); 00482 free_vector_d(model->priors); 00483 free(model); 00484 } 00485 00486 
00488 static Error* read_nb_gauss_xml_doc 00489 ( 00490 xmlDoc** xml_doc_out, 00491 const char* xml_fname, 00492 const char* dtd_fname 00493 ) 00494 { 00495 xmlParserCtxt* xml_parse_ctxt; 00496 xmlValidCtxt* xml_valid_ctxt; 00497 xmlDtd* xml_dtd; 00498 int slen = 256; 00499 char str[slen]; 00500 00501 assert(xml_parse_ctxt = xmlNewParserCtxt()); 00502 00503 if (!(*xml_doc_out = xmlCtxtReadFile(xml_parse_ctxt, xml_fname, NULL, 0))) 00504 { 00505 snprintf(str, slen, "%s: %s", xml_fname, "Could not parse file"); 00506 return JWSC_EIO(str); 00507 } 00508 00509 xmlFreeParserCtxt(xml_parse_ctxt); 00510 00511 if (dtd_fname) 00512 { 00513 assert(xml_valid_ctxt = xmlNewValidCtxt()); 00514 00515 if (!(xml_dtd = xmlParseDTD(NULL, (xmlChar*)dtd_fname))) 00516 { 00517 snprintf(str, slen, "%s: %s", dtd_fname, "Could not parse DTD"); 00518 return JWSC_EIO(str); 00519 } 00520 00521 if (!xmlValidateDtd(xml_valid_ctxt, *xml_doc_out, xml_dtd)) 00522 { 00523 snprintf(str, slen, "%s: %s", xml_fname, "XML file not valid"); 00524 return JWSC_EIO(str); 00525 } 00526 00527 xmlFreeValidCtxt(xml_valid_ctxt); 00528 xmlFreeDtd(xml_dtd); 00529 } 00530 00531 return NULL; 00532 } 00533 00534 00541 static Error* create_nb_gauss_model_tree_from_xml_node 00542 ( 00543 NB_gauss_model_tree** tree_out, 00544 const NB_gauss_model_tree* parent, 00545 uint32_t parent_label, 00546 xmlNode* xml_node 00547 ) 00548 { 00549 NB_gauss_model_tree* tree; 00550 uint32_t i, j; 00551 uint32_t label; 00552 uint32_t altlabel; 00553 uint32_t num_altlabels; 00554 double p; 00555 xmlAttr* attr; 00556 xmlNode* it; 00557 xmlNode* itt; 00558 Error* err; 00559 const char* fname; 00560 00561 if (*tree_out) 00562 { 00563 free_nb_gauss_model_tree(*tree_out); 00564 } 00565 00566 assert(*tree_out = malloc(sizeof(NB_gauss_model_tree))); 00567 tree = *tree_out; 00568 00569 tree->parent = parent; 00570 tree->parent_label = parent_label; 00571 tree->num_groups = 0; 00572 tree->subtrees = NULL; 00573 tree->labels = NULL; 00574 
tree->altlabels = NULL; 00575 tree->priors = NULL; 00576 tree->model = NULL; 00577 tree->model_fname = NULL; 00578 00579 assert(XMLStrEqual(xml_node->name, "model")); 00580 00581 for (it = xml_node->children; it; it = it->next) 00582 { 00583 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "group")) 00584 { 00585 tree->num_groups++; 00586 } 00587 } 00588 00589 assert(tree->subtrees = calloc(tree->num_groups, sizeof(void*))); 00590 assert(tree->altlabels = calloc(tree->num_groups, sizeof(void*))); 00591 create_vector_u32(&(tree->labels), tree->num_groups); 00592 00593 for (attr = xml_node->properties; attr; attr = attr->next) 00594 { 00595 if (XMLStrEqual(attr->name, "priors") && 00596 XMLStrEqual(attr->children->content, "true")) 00597 { 00598 create_init_vector_d(&(tree->priors), get_num_haplo_groups(), 1.0); 00599 break; 00600 } 00601 } 00602 00603 i = 0; 00604 for (it = xml_node->children; it; it = it->next) 00605 { 00606 j = 0; 00607 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "file")) 00608 { 00609 fname = (const char*) it->children->content; 00610 tree->model_fname = malloc((strlen(fname)+1) * sizeof(char)); 00611 strcpy(tree->model_fname, fname); 00612 } 00613 else if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "group")) 00614 { 00615 num_altlabels = 0; 00616 for (itt = it->children; itt; itt = itt->next) 00617 { 00618 if (itt->type == XML_ELEMENT_NODE && 00619 XMLStrEqual(itt->name, "altlabel")) 00620 { 00621 num_altlabels++; 00622 } 00623 } 00624 if (num_altlabels) 00625 { 00626 create_vector_u32(&(tree->altlabels[ i ]), num_altlabels); 00627 } 00628 00629 for (itt = it->children; itt; itt = itt->next) 00630 { 00631 if (itt->type == XML_ELEMENT_NODE && 00632 XMLStrEqual(itt->name, "label")) 00633 { 00634 if ((err = lookup_haplo_group_index_from_label(&label, 00635 (const char*) itt->children->content))) 00636 { 00637 return err; 00638 } 00639 tree->labels->elts[ i ] = label; 00640 } 00641 else if (itt->type == 
XML_ELEMENT_NODE && 00642 XMLStrEqual(itt->name, "altlabel")) 00643 { 00644 if ((err = lookup_haplo_group_index_from_label(&altlabel, 00645 (const char*) itt->children->content))) 00646 { 00647 return err; 00648 } 00649 tree->altlabels[ i ]->elts[ j++ ] = altlabel; 00650 } 00651 else if (tree->priors && itt->type == XML_ELEMENT_NODE && 00652 XMLStrEqual(itt->name, "prior")) 00653 { 00654 if (sscanf((char*)itt->children->content, "%lf", &p) != 1 || 00655 p < 0 || p > 1) 00656 { 00657 return JWSC_EARG("Invalid prior"); 00658 } 00659 tree->priors->elts[ i ] = p; 00660 } 00661 else if (itt->type == XML_ELEMENT_NODE && 00662 XMLStrEqual(itt->name, "model")) 00663 { 00664 if ((err = create_nb_gauss_model_tree_from_xml_node( 00665 &(tree->subtrees[ i ]), tree, label, itt))) 00666 { 00667 return err; 00668 } 00669 } 00670 } 00671 i++; 00672 } 00673 } 00674 assert(i == tree->num_groups); 00675 00676 return NULL; 00677 } 00678 00679 00684 static Error* create_model_training_data 00685 ( 00686 Vector_u32** train_labels_out, 00687 Matrix_i32** train_markers_out, 00688 const Vector_u32* data_labels, 00689 const Matrix_i32* data_markers, 00690 const Vector_u32* model_labels, 00691 Vector_u32*const* model_altlabels 00692 ) 00693 { 00694 uint8_t b; 00695 uint32_t i, j, k; 00696 uint32_t n; 00697 Vector_u32* train_labels = NULL; 00698 Matrix_i32* train_markers = NULL; 00699 00700 copy_vector_u32(&train_labels, data_labels); 00701 copy_matrix_i32(&train_markers, data_markers); 00702 00703 n = 0; 00704 for (i = 0; i < data_labels->num_elts; i++) 00705 { 00706 for (j = 0; j < model_labels->num_elts; j++) 00707 { 00708 if (is_ancestor(data_labels->elts[ i ], model_labels->elts[ j ])) 00709 { 00710 train_labels->elts[ n ] = model_labels->elts[ j ]; 00711 copy_matrix_block_into_matrix_i32(train_markers, n, 0, 00712 data_markers, i, 0, 1, data_markers->num_cols); 00713 n++; 00714 break; 00715 } 00716 else if (model_altlabels[ j ]) 00717 { 00718 for (k = 0; k < model_altlabels[ j 
]->num_elts; k++) 00719 { 00720 if ((b = is_ancestor(data_labels->elts[ i ], 00721 model_altlabels[ j ]->elts[ k ]))) 00722 { 00723 train_labels->elts[ n ] = model_labels->elts[ j ]; 00724 copy_matrix_block_into_matrix_i32(train_markers, n, 0, 00725 data_markers, i, 0, 1, data_markers->num_cols); 00726 n++; 00727 break; 00728 } 00729 } 00730 if (b) 00731 { 00732 break; 00733 } 00734 } 00735 } 00736 } 00737 00738 if (!n) 00739 { 00740 return JWSC_EARG("No data for model"); 00741 } 00742 00743 copy_vector_section_u32(train_labels_out, train_labels, 0, n); 00744 copy_matrix_block_i32(train_markers_out, train_markers, 0, 0, n, 00745 train_markers->num_cols); 00746 00747 free_vector_u32(train_labels); 00748 free_matrix_i32(train_markers); 00749 00750 return NULL; 00751 } 00752 00753 00755 static Error* train_nb_gauss_model_node 00756 ( 00757 NB_gauss_model_node* node, 00758 const Vector_u32* labels, 00759 const Matrix_i32* markers 00760 ) 00761 { 00762 uint32_t i; 00763 Vector_u32* node_labels = NULL; 00764 Matrix_i32* node_markers = NULL; 00765 Error* err; 00766 int slen = 256; 00767 char str[slen]; 00768 00769 if ((err = create_model_training_data(&node_labels, &node_markers, labels, 00770 markers, node->labels, node->altlabels))) 00771 { 00772 snprintf(str, slen, "%s: %s", node->model_fname, err->msg); 00773 return JWSC_EARG(str); 00774 } 00775 00776 train_nb_gauss_model(&(node->model), node_labels, node_markers, 00777 node->priors); 00778 00779 for (i = 0; i < node->num_groups; i++) 00780 { 00781 if (node->subtrees[ i ]) 00782 { 00783 if ((err = train_nb_gauss_model_node(node->subtrees[ i ], labels, 00784 markers))) 00785 { 00786 return err; 00787 } 00788 } 00789 } 00790 00791 free_vector_u32(node_labels); 00792 free_matrix_i32(node_markers); 00793 00794 return NULL; 00795 } 00796 00797 00805 Error* train_nb_gauss_model_tree 00806 ( 00807 NB_gauss_model_tree** tree_out, 00808 const Vector_u32* labels, 00809 const Matrix_i32* markers, 00810 const char* 
tree_xml_fname, 00811 const char* tree_dtd_fname 00812 ) 00813 { 00814 xmlDoc* xml_doc; 00815 xmlNode* xml_node; 00816 xmlNode* it; 00817 Error* err; 00818 int slen = 256; 00819 char str[slen]; 00820 00821 assert(tree_xml_fname); 00822 00823 if ((err = read_nb_gauss_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname))) 00824 { 00825 return err; 00826 } 00827 00828 for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next) 00829 { 00830 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model")) 00831 { 00832 xml_node = it; 00833 break; 00834 } 00835 } 00836 00837 if ((err = create_nb_gauss_model_tree_from_xml_node(tree_out, NULL, 0, 00838 xml_node))) 00839 { 00840 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 00841 return JWSC_EARG(str); 00842 } 00843 00844 if ((err = train_nb_gauss_model_node(*tree_out, labels, markers))) 00845 { 00846 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 00847 return JWSC_EARG(str); 00848 } 00849 00850 return NULL; 00851 } 00852 00853 00855 static Error* recursively_predict_label_in_model_tree 00856 ( 00857 uint32_t* label_out, 00858 double* confidence_out, 00859 const NB_gauss_model_tree* tree, 00860 const Vector_i32* markers_v, 00861 uint32_t order 00862 ) 00863 { 00864 uint32_t subtree_label; 00865 uint32_t i; 00866 double confidence; 00867 Error* err; 00868 00869 if ((err = predict_label_with_nb_gauss_model(label_out, &confidence, 00870 markers_v, tree->model, order)) != NULL) 00871 { 00872 return err; 00873 } 00874 *confidence_out *= confidence; 00875 00876 for (i = 0; i < tree->num_groups; i++) 00877 { 00878 if (tree->subtrees[ i ]) 00879 { 00880 subtree_label = tree->subtrees[ i ]->parent_label; 00881 00882 if (subtree_label == *label_out) 00883 { 00884 if ((err = recursively_predict_label_in_model_tree(label_out, 00885 confidence_out, tree->subtrees[ i ], markers_v, 00886 order)) != NULL) 00887 { 00888 return err; 00889 } 00890 } 00891 } 00892 } 00893 00894 return NULL; 00895 } 00896 00897 
00913 Error* predict_labels_with_nb_gauss_model_tree 00914 ( 00915 Vector_u32** labels_out, 00916 Vector_d** confidence_out, 00917 const Matrix_i32* markers, 00918 const NB_gauss_model_tree* tree, 00919 uint32_t order 00920 ) 00921 { 00922 uint32_t s, num_samples; 00923 uint32_t m, num_markers; 00924 00925 Vector_u32* labels; 00926 Vector_d* confidence; 00927 Vector_i32* markers_v; 00928 Error* err; 00929 00930 num_samples = markers->num_rows; 00931 num_markers = markers->num_cols; 00932 00933 create_vector_u32(labels_out, num_samples); 00934 labels = *labels_out; 00935 00936 create_init_vector_d(confidence_out, num_samples, 1.0); 00937 confidence = *confidence_out; 00938 00939 markers_v = NULL; 00940 create_vector_i32(&markers_v, num_markers); 00941 00942 for (s = 0; s < num_samples; s++) 00943 { 00944 for (m = 0; m < num_markers; m++) 00945 { 00946 markers_v->elts[ m ]= markers->elts[ s ][ m ]; 00947 } 00948 00949 if ((err = recursively_predict_label_in_model_tree( 00950 &(labels->elts[ s ]), 00951 &(confidence->elts[ s ]), tree, markers_v, order)) 00952 != NULL) 00953 { 00954 return err; 00955 } 00956 } 00957 00958 free_vector_i32(markers_v); 00959 00960 return NULL; 00961 } 00962 00963 00965 static Error* read_nb_gauss_model_node 00966 ( 00967 NB_gauss_model_node* node, 00968 const char* model_dirname 00969 ) 00970 { 00971 uint32_t i; 00972 char buf[1024] = {0}; 00973 Error* err; 00974 00975 snprintf(buf, 1024, "%s/%s", model_dirname, node->model_fname); 00976 if ((err = read_nb_gauss_model(&(node->model), buf))) 00977 { 00978 return err; 00979 } 00980 00981 for (i = 0; i < node->num_groups; i++) 00982 { 00983 if (node->subtrees[ i ]) 00984 { 00985 if ((err = read_nb_gauss_model_node(node->subtrees[ i ], 00986 model_dirname))) 00987 { 00988 return err; 00989 } 00990 } 00991 } 00992 00993 return NULL; 00994 } 00995 00996 01003 Error* read_nb_gauss_model_tree 01004 ( 01005 NB_gauss_model_tree** tree_out, 01006 const char* tree_xml_fname, 01007 const char* 
tree_dtd_fname, 01008 const char* model_dirname 01009 ) 01010 { 01011 xmlDoc* xml_doc; 01012 xmlNode* xml_node; 01013 xmlNode* it; 01014 Error* err; 01015 int slen = 256; 01016 char str[slen]; 01017 01018 assert(tree_xml_fname && model_dirname); 01019 01020 if ((err = read_nb_gauss_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname))) 01021 { 01022 return err; 01023 } 01024 01025 for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next) 01026 { 01027 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model")) 01028 { 01029 xml_node = it; 01030 break; 01031 } 01032 } 01033 01034 if ((err = create_nb_gauss_model_tree_from_xml_node(tree_out, NULL, 0, 01035 xml_node))) 01036 { 01037 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 01038 return JWSC_EARG(str); 01039 } 01040 01041 if ((err = read_nb_gauss_model_node(*tree_out, model_dirname))) 01042 { 01043 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 01044 return JWSC_EARG(str); 01045 } 01046 01047 return NULL; 01048 } 01049 01050 01055 Error* write_nb_gauss_model_tree 01056 ( 01057 const NB_gauss_model_tree* tree, 01058 const char* model_dirname 01059 ) 01060 { 01061 uint32_t i; 01062 char buf[1024] = {0}; 01063 Error* err; 01064 int slen = 256; 01065 char str[slen]; 01066 01067 snprintf(buf, 1024, "%s/%s", model_dirname, tree->model_fname); 01068 if ((err = write_nb_gauss_model(tree->model, buf))) 01069 { 01070 return err; 01071 } 01072 01073 for (i = 0; i < tree->num_groups; i++) 01074 { 01075 if (tree->subtrees[ i ]) 01076 { 01077 if ((err = write_nb_gauss_model_tree(tree->subtrees[ i ], 01078 model_dirname))) 01079 { 01080 snprintf(str, slen, "%s: %s", tree->model_fname, err->msg); 01081 return JWSC_EIO(str); 01082 } 01083 } 01084 } 01085 01086 return NULL; 01087 } 01088 01089 01091 void free_nb_gauss_model_tree(NB_gauss_model_tree* tree) 01092 { 01093 uint32_t i; 01094 01095 if (!tree) 01096 return; 01097 01098 for (i = 0; i < tree->num_groups; i++) 01099 { 01100 
free_nb_gauss_model_tree(tree->subtrees[ i ]); 01101 free_vector_u32(tree->altlabels[ i ]); 01102 } 01103 01104 free(tree->subtrees); 01105 free(tree->altlabels); 01106 free_vector_u32(tree->labels); 01107 free_vector_d(tree->priors); 01108 free_nb_gauss_model(tree->model); 01109 free(tree->model_fname); 01110 free(tree); 01111 }