Haplo Prediction
predict haplogroups
|
00001 /* 00002 * This work is licensed under a Creative Commons 00003 * Attribution-Noncommercial-Share Alike 3.0 United States License. 00004 * 00005 * http://creativecommons.org/licenses/by-nc-sa/3.0/us/ 00006 * 00007 * You are free: 00008 * 00009 * to Share - to copy, distribute, display, and perform the work 00010 * to Remix - to make derivative works 00011 * 00012 * Under the following conditions: 00013 * 00014 * Attribution. You must attribute the work in the manner specified by the 00015 * author or licensor (but not in any way that suggests that they endorse you 00016 * or your use of the work). 00017 * 00018 * Noncommercial. You may not use this work for commercial purposes. 00019 * 00020 * Share Alike. If you alter, transform, or build upon this work, you may 00021 * distribute the resulting work only under the same or similar license to 00022 * this one. 00023 * 00024 * For any reuse or distribution, you must make clear to others the license 00025 * terms of this work. The best way to do this is by including this header. 00026 * 00027 * Any of the above conditions can be waived if you get permission from the 00028 * copyright holder. 00029 * 00030 * Apart from the remix rights granted under this license, nothing in this 00031 * license impairs or restricts the author's moral rights. 00032 */ 00033 00034 00046 #include <config.h> 00047 00048 #include <stdlib.h> 00049 #include <stdio.h> 00050 #include <inttypes.h> 00051 #include <assert.h> 00052 #include <string.h> 00053 #include <errno.h> 00054 #include <math.h> 00055 00056 #include <libxml/tree.h> 00057 #include <libxml/parser.h> 00058 #include <libxml/valid.h> 00059 00060 #ifdef HAPLO_HAVE_DMALLOC 00061 #include <dmalloc.h> 00062 #endif 00063 00064 #include <jwsc/base/error.h> 00065 #include <jwsc/base/limits.h> 00066 #include <jwsc/base/file_io.h> 00067 #include <jwsc/vector/vector.h> 00068 #include <jwsc/vector/vector_io.h> 00069 #include <jwsc/vector/vector_math.h> 00070 #include <jwsc/matrix/matrix.h> 00071 #include <jwsc/matrix/matrix_io.h> 00072 #include <jwsc/matblock/matblock.h> 00073 #include <jwsc/matblock/matblock_io.h> 00074 #include <jwsc/stat/gmm.h> 00075 00076 #include "xml.h" 00077 #include "haplo_groups.h" 00078 #include "nb_gmm.h" 00079 00080 00091 void train_nb_gmm_model 00092 ( 00093 NB_gmm_model** model_out, 00094 const Vector_u32* labels, 00095 const Matrix_i32* markers, 00096 const Vector_d* label_priors, 00097 uint32_t num_components 00098 ) 00099 { 00100 uint32_t l, num_labels; 00101 uint32_t s, num_samples; 00102 uint32_t m, num_markers; 00103 uint32_t n, N; 00104 00105 Matrix_d*** mu; 00106 Matblock_d*** sigma; 00107 Vector_d*** pi; 00108 Vector_d* priors; 00109 Matrix_d* markers_d = NULL; 00110 00111 assert(markers->num_rows == labels->num_elts); 00112 00113 num_markers = markers->num_cols; 00114 num_samples = markers->num_rows; 00115 num_labels = get_num_haplo_groups(); 00116 00117 if (*model_out != NULL) 00118 { 00119 free_nb_gmm_model(*model_out); 00120 } 00121 00122 assert(mu = malloc(num_labels*sizeof(Matrix_d*))); 00123 assert(sigma = malloc(num_labels*sizeof(Matblock_d*))); 00124 assert(pi = malloc(num_labels*sizeof(Vector_d*))); 00125 priors = NULL; 00126 00127 for (l = 0; l < num_labels; l++) 00128 { 00129 mu [ l ] = NULL; 00130 sigma[ l ] = NULL; 00131 pi [ l ] = NULL; 00132 00133 N = 0; 00134 for (s = 0; s < num_samples; s++) 00135 { 00136 if (labels->elts[ s ] == l) 00137 { 00138 N++; 00139 } 00140 } 00141 00142 // TODO we should probably handle the case of N=1 by creating a 00143 // distribution with the sample as the mean and a very small diagonal 00144 // covariance matrix. 00145 if (N < num_components*2) 00146 continue; 00147 00148 assert(mu[ l ] = malloc(num_markers*sizeof(void*))); 00149 assert(sigma[ l ] = malloc(num_markers*sizeof(void*))); 00150 assert(pi[ l ] = malloc(num_markers*sizeof(void*))); 00151 00152 create_matrix_d(&markers_d, N, 1); 00153 00154 for (m = 0; m < num_markers; m++) 00155 { 00156 n = 0; 00157 for (s = 0; s < num_samples; s++) 00158 { 00159 if (labels->elts[ s ] == l) 00160 { 00161 markers_d->elts[ n ][ 0 ] = markers->elts[ s ][ m ]; 00162 n++; 00163 } 00164 } 00165 00166 mu[ l ][ m ] = NULL; 00167 sigma[ l ][ m ] = NULL; 00168 pi[ l ][ m ] = NULL; 00169 00170 train_gmm_d(&(mu[l][m]), &(sigma[l][m]), &(pi[l][m]), NULL, 00171 markers_d, num_components); 00172 } 00173 } 00174 00175 free_matrix_d(markers_d); 00176 00177 if (label_priors == NULL) 00178 { 00179 create_zero_vector_d(&priors, num_labels); 00180 00181 for (s = 0; s < num_samples; s++) 00182 { 00183 l = labels->elts[ s ]; 00184 priors->elts[ l ]++; 00185 } 00186 normalize_vector_sum_d(&priors, priors); 00187 } 00188 else 00189 { 00190 normalize_vector_sum_d(&priors, label_priors); 00191 } 00192 00193 assert(*model_out = malloc(sizeof(NB_gmm_model))); 00194 (*model_out)->num_markers = num_markers; 00195 (*model_out)->num_components = num_components; 00196 (*model_out)->mu = mu; 00197 (*model_out)->sigma = sigma; 00198 (*model_out)->pi = pi; 00199 (*model_out)->priors = priors; 00200 } 00201 00202 00214 Error* predict_label_with_nb_gmm_model 00215 ( 00216 uint32_t* label_out, 00217 double* confidence_out, 00218 const Vector_i32* markers, 00219 const NB_gmm_model* model, 00220 uint32_t order 00221 ) 00222 { 00223 uint32_t num_markers; 00224 uint32_t marker; 00225 uint32_t num_labels; 00226 uint32_t label; 00227 uint32_t best_label; 00228 uint32_t i; 00229 double log_map; 00230 double log_ll; 00231 double log_prior; 00232 double log_posterior; 00233 double posterior_sum; 00234 00235 Vector_u32* best_labels = NULL; 00236 Vector_d* best_labels_conf = NULL; 00237 Matrix_d* marker_val = NULL; 00238 00239 const Matrix_d* mu; 00240 const Matblock_d* sigma; 00241 const Vector_d* pi; 00242 00243 num_labels = get_num_haplo_groups(); 00244 num_markers = markers->num_elts; 00245 00246 log_map = log(JWSC_MIN_LOG_ARG); 00247 posterior_sum = 0; 00248 00249 create_init_vector_u32(&best_labels, num_labels, 0); 00250 create_init_vector_d(&best_labels_conf, num_labels, 0); 00251 00252 create_matrix_d(&marker_val, 1, 1); 00253 00254 for (label = 0; label < num_labels; label++) 00255 { 00256 if (!model->mu[ label ]) 00257 continue; 00258 00259 log_ll = 0; 00260 00261 for (marker = 0; marker < num_markers; marker++) 00262 { 00263 marker_val->elts[0][0] = markers->elts[ marker ]; 00264 00265 /* Ignore non-positive marker values. */ 00266 if (marker_val > 0) 00267 { 00268 mu = model->mu[ label ][ marker ]; 00269 sigma = model->sigma[ label ][ marker ]; 00270 pi = model->pi[ label ][ marker ]; 00271 00272 log_ll += gmm_log_likelihood_d(mu, sigma, pi, marker_val); 00273 } 00274 } 00275 00276 if (model->priors->elts[ label ] < JWSC_MIN_LOG_ARG) 00277 { 00278 log_prior = log(JWSC_MIN_LOG_ARG); 00279 } 00280 else 00281 { 00282 log_prior = log(model->priors->elts[ label ]); 00283 } 00284 00285 log_posterior = log_ll + log_prior; 00286 posterior_sum += exp(log_posterior); 00287 00288 if (log_posterior > log_map) 00289 { 00290 log_map = log_posterior; 00291 best_label = label; 00292 for (i = num_labels-1; i > 0; i--) 00293 { 00294 best_labels->elts[ i ] = best_labels->elts[ i - 1]; 00295 best_labels_conf->elts[ i ] = best_labels_conf->elts[ i - 1]; 00296 } 00297 best_labels->elts[ 0 ] = best_label; 00298 best_labels_conf->elts[ 0 ] = log_map; 00299 } 00300 } 00301 00302 assert(order < best_labels->num_elts); 00303 00304 *label_out = best_labels->elts[ order ]; 00305 00306 if (posterior_sum > 0) 00307 { 00308 assert(finite(*confidence_out = exp(best_labels_conf->elts[order]) / 00309 (double)posterior_sum)); 00310 } 00311 else 00312 { 00313 *confidence_out = 0; 00314 } 00315 00316 free_vector_u32(best_labels); 00317 free_vector_d(best_labels_conf); 00318 free_matrix_d(marker_val); 00319 00320 return NULL; 00321 } 00322 00323 00337 Error* predict_labels_with_nb_gmm_model 00338 ( 00339 Vector_u32** labels_out, 00340 Vector_d** confidence_out, 00341 const Matrix_i32* markers, 00342 const NB_gmm_model* model 00343 ) 00344 { 00345 uint32_t num_samples; 00346 uint32_t sample; 00347 uint32_t num_markers; 00348 uint32_t marker; 00349 00350 Vector_u32* labels; 00351 Vector_d* confidence; 00352 Vector_i32* markers_v; 00353 Error* e; 00354 00355 num_samples = markers->num_rows; 00356 num_markers = markers->num_cols; 00357 00358 create_vector_u32(labels_out, num_samples); 00359 labels = *labels_out; 00360 00361 create_vector_d(confidence_out, num_samples); 00362 confidence = *confidence_out; 00363 00364 markers_v = NULL; 00365 create_vector_i32(&markers_v, num_markers); 00366 00367 for (sample = 0; sample < num_samples; sample++) 00368 { 00369 for (marker = 0; marker < num_markers; marker++) 00370 { 00371 markers_v->elts[ marker ] = markers->elts[ sample ][marker]; 00372 } 00373 00374 if ((e = predict_label_with_nb_gmm_model(&(labels->elts[ sample ]), 00375 &(confidence->elts[ sample ]), markers_v, model, 0)) != NULL) 00376 { 00377 return e; 00378 } 00379 } 00380 00381 free_vector_i32(markers_v); 00382 00383 return NULL; 00384 } 00385 00386 00395 Error* read_nb_gmm_model(NB_gmm_model** model_out, const char* fname) 00396 { 00397 uint32_t num_labels; 00398 uint32_t i, j; 00399 FILE* fp; 00400 NB_gmm_model* model; 00401 Error* e; 00402 int slen = 256; 00403 char str[slen]; 00404 00405 num_labels = get_num_haplo_groups(); 00406 00407 if ((fp = fopen(fname, "r")) == NULL) 00408 { 00409 snprintf(str, slen, "%s: %s", fname, strerror(errno)); 00410 return JWSC_EARG(str); 00411 } 00412 00413 if (*model_out == NULL) 00414 { 00415 assert((*model_out = malloc(sizeof(NB_gmm_model))) != NULL); 00416 model = *model_out; 00417 model->priors = NULL; 00418 } 00419 00420 if (fscanf(fp, "%u\n%u\n", &(model->num_markers), 00421 &(model->num_components)) != 2) 00422 { 00423 snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file"); 00424 return JWSC_EARG(str); 00425 } 00426 00427 assert(model->mu = malloc(num_labels*sizeof(void*))); 00428 assert(model->sigma = malloc(num_labels*sizeof(void*))); 00429 assert(model->pi = malloc(num_labels*sizeof(void*))); 00430 00431 for (i = 0; i < num_labels; i++) 00432 { 00433 if (fscanf(fp, "%u", (uint32_t*)(&(model->mu[ i ]))) != 1) 00434 { 00435 return JWSC_EARG("Improperly formatted model file"); 00436 } 00437 } 00438 00439 for (i = 0; i < num_labels; i++) 00440 { 00441 model->sigma[ i ] = NULL; 00442 model->pi[ i ] = NULL; 00443 00444 if (model->mu[ i ]) 00445 { 00446 assert(model->mu[i] = malloc(model->num_markers*sizeof(void*))); 00447 assert(model->sigma[i] = malloc(model->num_markers*sizeof(void*))); 00448 assert(model->pi[i] = malloc(model->num_markers*sizeof(void*))); 00449 00450 for (j = 0; j < model->num_markers; j++) 00451 { 00452 model->mu[ i ][ j ] = NULL; 00453 model->sigma[ i ][ j ] = NULL; 00454 model->pi[ i ][ j ] = NULL; 00455 00456 if ((e = read_matrix_with_header_fp_d( 00457 &(model->mu[i][j]), fp)) || 00458 (e = read_matblock_with_header_fp_d( 00459 &(model->sigma[i][j]), fp)) || 00460 (e = read_vector_with_header_fp_d( 00461 &(model->pi[i][j]), fp))) 00462 { 00463 return JWSC_EARG("Improperly formatted model file"); 00464 } 00465 } 00466 } 00467 } 00468 00469 if ((e = read_vector_with_header_fp_d(&(model->priors), fp)) != NULL) 00470 { 00471 snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file"); 00472 return JWSC_EARG(str); 00473 } 00474 00475 if (fclose(fp) != 0) 00476 { 00477 snprintf(str, slen, "%s: %s", fname, strerror(errno)); 00478 return JWSC_EARG(str); 00479 } 00480 00481 return NULL; 00482 } 00483 00484 00491 Error* write_nb_gmm_model(NB_gmm_model* model, const char* fname) 00492 { 00493 uint32_t num_labels; 00494 uint32_t i, j; 00495 FILE* fp; 00496 int slen = 256; 00497 char str[slen]; 00498 00499 num_labels = get_num_haplo_groups(); 00500 00501 if ((fp = fopen(fname, "w")) == NULL) 00502 { 00503 snprintf(str, slen, "%s: %s", fname, strerror(errno)); 00504 return JWSC_EARG(str); 00505 } 00506 00507 fprintf(fp, "%u\n", model->num_markers); 00508 fprintf(fp, "%u\n", model->num_components); 00509 00510 for (i = 0; i < num_labels; i++) 00511 { 00512 fprintf(fp, "%u ", (model->mu[ i ] != NULL)); 00513 } 00514 fprintf(fp, "\n"); 00515 00516 for (i = 0; i < num_labels; i++) 00517 { 00518 if (model->mu[ i ]) 00519 { 00520 for (j = 0; j < model->num_markers; j++) 00521 { 00522 write_matrix_with_header_fp_d(model->mu[ i ][ j ], fp); 00523 write_matblock_with_header_fp_d(model->sigma[ i ][ j ], fp); 00524 write_vector_with_header_fp_d(model->pi[ i ][ j ], fp); 00525 } 00526 } 00527 } 00528 00529 write_vector_with_header_fp_d(model->priors, fp); 00530 00531 if (fclose(fp) != 0) 00532 { 00533 snprintf(str, slen, "%s: %s", fname, strerror(errno)); 00534 return JWSC_EARG(str); 00535 } 00536 00537 return NULL; 00538 } 00539 00540 00542 void free_nb_gmm_model(NB_gmm_model* model) 00543 { 00544 uint32_t i, j; 00545 uint32_t num_labels; 00546 00547 if (!model) 00548 return; 00549 00550 num_labels = get_num_haplo_groups(); 00551 00552 for (i = 0; i < num_labels; i++) 00553 { 00554 if (model->mu[ i ]) 00555 { 00556 for (j = 0; j < model->num_markers; j++) 00557 { 00558 free_matrix_d(model->mu[ i ][ j ]); 00559 free_matblock_d(model->sigma[ i ][ j ]); 00560 free_vector_d(model->pi[ i ][ j ]); 00561 } 00562 free(model->mu[ i ]); 00563 free(model->sigma[ i ]); 00564 free(model->pi[ i ]); 00565 } 00566 } 00567 free(model->mu); 00568 free(model->sigma); 00569 free(model->pi); 00570 free_vector_d(model->priors); 00571 free(model); 00572 } 00573 00574 00576 static Error* read_nb_gmm_xml_doc 00577 ( 00578 xmlDoc** xml_doc_out, 00579 const char* xml_fname, 00580 const char* dtd_fname 00581 ) 00582 { 00583 xmlParserCtxt* xml_parse_ctxt; 00584 xmlValidCtxt* xml_valid_ctxt; 00585 xmlDtd* xml_dtd; 00586 int slen = 256; 00587 char str[slen]; 00588 00589 assert(xml_parse_ctxt = xmlNewParserCtxt()); 00590 00591 if (!(*xml_doc_out = xmlCtxtReadFile(xml_parse_ctxt, xml_fname, NULL, 0))) 00592 { 00593 snprintf(str, slen, "%s: %s", xml_fname, "Could not parse file"); 00594 return JWSC_EARG(str); 00595 } 00596 00597 xmlFreeParserCtxt(xml_parse_ctxt); 00598 00599 if (dtd_fname) 00600 { 00601 assert(xml_valid_ctxt = xmlNewValidCtxt()); 00602 00603 if (!(xml_dtd = xmlParseDTD(NULL, (xmlChar*)dtd_fname))) 00604 { 00605 snprintf(str, slen, "%s: %s", dtd_fname, "Could not parse DTD"); 00606 return JWSC_EARG(str); 00607 } 00608 00609 if (!xmlValidateDtd(xml_valid_ctxt, *xml_doc_out, xml_dtd)) 00610 { 00611 snprintf(str, slen, "%s: %s", xml_fname, "XML file not valid"); 00612 return JWSC_EARG(str); 00613 } 00614 00615 xmlFreeValidCtxt(xml_valid_ctxt); 00616 xmlFreeDtd(xml_dtd); 00617 } 00618 00619 return NULL; 00620 } 00621 00622 00629 static Error* create_nb_gmm_model_tree_from_xml_node 00630 ( 00631 NB_gmm_model_tree** tree_out, 00632 const NB_gmm_model_tree* parent, 00633 uint32_t parent_label, 00634 xmlNode* xml_node 00635 ) 00636 { 00637 NB_gmm_model_tree* tree; 00638 uint32_t i, j; 00639 uint32_t label; 00640 uint32_t altlabel; 00641 uint32_t num_altlabels; 00642 double p; 00643 xmlAttr* attr; 00644 xmlNode* it; 00645 xmlNode* itt; 00646 Error* err; 00647 const char* fname; 00648 00649 if (*tree_out) 00650 { 00651 free_nb_gmm_model_tree(*tree_out); 00652 } 00653 00654 assert(*tree_out = malloc(sizeof(NB_gmm_model_tree))); 00655 tree = *tree_out; 00656 00657 tree->parent = parent; 00658 tree->parent_label = parent_label; 00659 tree->num_groups = 0; 00660 tree->subtrees = NULL; 00661 tree->labels = NULL; 00662 tree->altlabels = NULL; 00663 tree->priors = NULL; 00664 tree->model = NULL; 00665 tree->model_fname = NULL; 00666 00667 assert(XMLStrEqual(xml_node->name, "model")); 00668 00669 for (it = xml_node->children; it; it = it->next) 00670 { 00671 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "group")) 00672 { 00673 tree->num_groups++; 00674 } 00675 } 00676 00677 assert(tree->subtrees = calloc(tree->num_groups, sizeof(void*))); 00678 assert(tree->altlabels = calloc(tree->num_groups, sizeof(void*))); 00679 create_vector_u32(&(tree->labels), tree->num_groups); 00680 00681 for (attr = xml_node->properties; attr; attr = attr->next) 00682 { 00683 if (XMLStrEqual(attr->name, "priors") && 00684 XMLStrEqual(attr->children->content, "true")) 00685 { 00686 create_init_vector_d(&(tree->priors), get_num_haplo_groups(), 1.0); 00687 break; 00688 } 00689 } 00690 00691 i = 0; 00692 for (it = xml_node->children; it; it = it->next) 00693 { 00694 j = 0; 00695 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "file")) 00696 { 00697 fname = (const char*) it->children->content; 00698 tree->model_fname = malloc((strlen(fname)+1) * sizeof(char)); 00699 strcpy(tree->model_fname, fname); 00700 } 00701 else if (it->type == XML_ELEMENT_NODE && 00702 XMLStrEqual(it->name, "mixture-components")) 00703 { 00704 if (sscanf((char*)it->children->content, "%d", 00705 &(tree->num_components)) != 1 || 00706 tree->num_components < 1) 00707 { 00708 return JWSC_EARG("Invalid number of mixture components"); 00709 } 00710 } 00711 else if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "group")) 00712 { 00713 num_altlabels = 0; 00714 for (itt = it->children; itt; itt = itt->next) 00715 { 00716 if (itt->type == XML_ELEMENT_NODE && 00717 XMLStrEqual(itt->name, "altlabel")) 00718 { 00719 num_altlabels++; 00720 } 00721 } 00722 if (num_altlabels) 00723 { 00724 create_vector_u32(&(tree->altlabels[ i ]), num_altlabels); 00725 } 00726 00727 for (itt = it->children; itt; itt = itt->next) 00728 { 00729 if (itt->type == XML_ELEMENT_NODE && 00730 XMLStrEqual(itt->name, "label")) 00731 { 00732 if ((err = lookup_haplo_group_index_from_label(&label, 00733 (const char*) itt->children->content))) 00734 { 00735 return err; 00736 } 00737 tree->labels->elts[ i ] = label; 00738 } 00739 else if (itt->type == XML_ELEMENT_NODE && 00740 XMLStrEqual(itt->name, "altlabel")) 00741 { 00742 if ((err = lookup_haplo_group_index_from_label(&altlabel, 00743 (const char*) itt->children->content))) 00744 { 00745 return err; 00746 } 00747 tree->altlabels[ i ]->elts[ j++ ] = altlabel; 00748 } 00749 else if (tree->priors && itt->type == XML_ELEMENT_NODE && 00750 XMLStrEqual(itt->name, "prior")) 00751 { 00752 if (sscanf((char*)itt->children->content, "%lf", &p) != 1 || 00753 p < 0 || p > 1) 00754 { 00755 return JWSC_EARG("Invalid prior"); 00756 } 00757 tree->priors->elts[ i ] = p; 00758 } 00759 else if (itt->type == XML_ELEMENT_NODE && 00760 XMLStrEqual(itt->name, "model")) 00761 { 00762 if ((err = create_nb_gmm_model_tree_from_xml_node( 00763 &(tree->subtrees[ i ]), tree, label, itt))) 00764 { 00765 return err; 00766 } 00767 } 00768 } 00769 i++; 00770 } 00771 } 00772 assert(i == tree->num_groups); 00773 00774 return NULL; 00775 } 00776 00777 00782 static Error* create_model_training_data 00783 ( 00784 Vector_u32** train_labels_out, 00785 Matrix_i32** train_markers_out, 00786 const Vector_u32* data_labels, 00787 const Matrix_i32* data_markers, 00788 const Vector_u32* model_labels, 00789 Vector_u32*const* model_altlabels 00790 ) 00791 { 00792 uint8_t b; 00793 uint32_t i, j, k; 00794 uint32_t n; 00795 Vector_u32* train_labels = NULL; 00796 Matrix_i32* train_markers = NULL; 00797 00798 copy_vector_u32(&train_labels, data_labels); 00799 copy_matrix_i32(&train_markers, data_markers); 00800 00801 n = 0; 00802 for (i = 0; i < data_labels->num_elts; i++) 00803 { 00804 for (j = 0; j < model_labels->num_elts; j++) 00805 { 00806 if (is_ancestor(data_labels->elts[ i ], model_labels->elts[ j ])) 00807 { 00808 train_labels->elts[ n ] = model_labels->elts[ j ]; 00809 copy_matrix_block_into_matrix_i32(train_markers, n, 0, 00810 data_markers, i, 0, 1, data_markers->num_cols); 00811 n++; 00812 break; 00813 } 00814 else if (model_altlabels[ j ]) 00815 { 00816 for (k = 0; k < model_altlabels[ j ]->num_elts; k++) 00817 { 00818 if ((b = is_ancestor(data_labels->elts[ i ], 00819 model_altlabels[ j ]->elts[ k ]))) 00820 { 00821 train_labels->elts[ n ] = model_labels->elts[ j ]; 00822 copy_matrix_block_into_matrix_i32(train_markers, n, 0, 00823 data_markers, i, 0, 1, data_markers->num_cols); 00824 n++; 00825 break; 00826 } 00827 } 00828 if (b) 00829 { 00830 break; 00831 } 00832 } 00833 } 00834 } 00835 00836 if (!n) 00837 { 00838 return JWSC_EARG("No data for model"); 00839 } 00840 00841 copy_vector_section_u32(train_labels_out, train_labels, 0, n); 00842 copy_matrix_block_i32(train_markers_out, train_markers, 0, 0, n, 00843 train_markers->num_cols); 00844 00845 free_vector_u32(train_labels); 00846 free_matrix_i32(train_markers); 00847 00848 return NULL; 00849 } 00850 00851 00853 static Error* train_nb_gmm_model_node 00854 ( 00855 NB_gmm_model_node* node, 00856 const Vector_u32* labels, 00857 const Matrix_i32* markers 00858 ) 00859 { 00860 uint32_t i; 00861 Vector_u32* node_labels = NULL; 00862 Matrix_i32* node_markers = NULL; 00863 Error* err; 00864 int slen = 256; 00865 char str[slen]; 00866 00867 if ((err = create_model_training_data(&node_labels, &node_markers, labels, 00868 markers, node->labels, node->altlabels))) 00869 { 00870 snprintf(str, slen, "%s: %s", node->model_fname, err->msg); 00871 return JWSC_EARG(str); 00872 } 00873 00874 train_nb_gmm_model(&(node->model), node_labels, node_markers, 00875 node->priors, node->num_components); 00876 00877 for (i = 0; i < node->num_groups; i++) 00878 { 00879 if (node->subtrees[ i ]) 00880 { 00881 if ((err = train_nb_gmm_model_node(node->subtrees[ i ], labels, 00882 markers))) 00883 { 00884 return err; 00885 } 00886 } 00887 } 00888 00889 free_vector_u32(node_labels); 00890 free_matrix_i32(node_markers); 00891 00892 return NULL; 00893 } 00894 00895 00903 Error* train_nb_gmm_model_tree 00904 ( 00905 NB_gmm_model_tree** tree_out, 00906 const Vector_u32* labels, 00907 const Matrix_i32* markers, 00908 const char* tree_xml_fname, 00909 const char* tree_dtd_fname 00910 ) 00911 { 00912 xmlDoc* xml_doc; 00913 xmlNode* xml_node; 00914 xmlNode* it; 00915 Error* err; 00916 int slen = 256; 00917 char str[slen]; 00918 00919 assert(tree_xml_fname); 00920 00921 if ((err = read_nb_gmm_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname))) 00922 { 00923 return err; 00924 } 00925 00926 for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next) 00927 { 00928 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model")) 00929 { 00930 xml_node = it; 00931 break; 00932 } 00933 } 00934 00935 if ((err = create_nb_gmm_model_tree_from_xml_node(tree_out, NULL, 0, 00936 xml_node))) 00937 { 00938 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 00939 return JWSC_EARG(str); 00940 } 00941 00942 if ((err = train_nb_gmm_model_node(*tree_out, labels, markers))) 00943 { 00944 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 00945 return JWSC_EARG(str); 00946 } 00947 00948 return NULL; 00949 } 00950 00951 00953 static Error* recursively_predict_label_in_model_tree 00954 ( 00955 uint32_t* label_out, 00956 double* confidence_out, 00957 const NB_gmm_model_tree* tree, 00958 const Vector_i32* markers_v, 00959 uint32_t order 00960 ) 00961 { 00962 uint32_t subtree_label; 00963 uint32_t i; 00964 double confidence; 00965 Error* err; 00966 00967 if ((err = predict_label_with_nb_gmm_model(label_out, &confidence, 00968 markers_v, tree->model, order)) != NULL) 00969 { 00970 return err; 00971 } 00972 *confidence_out *= confidence; 00973 00974 for (i = 0; i < tree->num_groups; i++) 00975 { 00976 if (tree->subtrees[ i ]) 00977 { 00978 subtree_label = tree->subtrees[ i ]->parent_label; 00979 00980 if (subtree_label == *label_out) 00981 { 00982 if ((err = recursively_predict_label_in_model_tree(label_out, 00983 confidence_out, tree->subtrees[ i ], markers_v, 00984 order)) != NULL) 00985 { 00986 return err; 00987 } 00988 } 00989 } 00990 } 00991 00992 return NULL; 00993 } 00994 00995 01011 Error* predict_labels_with_nb_gmm_model_tree 01012 ( 01013 Vector_u32** labels_out, 01014 Vector_d** confidence_out, 01015 const Matrix_i32* markers, 01016 const NB_gmm_model_tree* tree, 01017 uint32_t order 01018 ) 01019 { 01020 uint32_t s, num_samples; 01021 uint32_t m, num_markers; 01022 01023 Vector_u32* labels; 01024 Vector_d* confidence; 01025 Vector_i32* markers_v; 01026 Error* err; 01027 01028 num_samples = markers->num_rows; 01029 num_markers = markers->num_cols; 01030 01031 create_vector_u32(labels_out, num_samples); 01032 labels = *labels_out; 01033 01034 create_init_vector_d(confidence_out, num_samples, 1.0); 01035 confidence = *confidence_out; 01036 01037 markers_v = NULL; 01038 create_vector_i32(&markers_v, num_markers); 01039 01040 for (s = 0; s < num_samples; s++) 01041 { 01042 for (m = 0; m < num_markers; m++) 01043 { 01044 markers_v->elts[ m ]= markers->elts[ s ][ m ]; 01045 } 01046 01047 if ((err = recursively_predict_label_in_model_tree( 01048 &(labels->elts[ s ]), 01049 &(confidence->elts[ s ]), tree, markers_v, order)) 01050 != NULL) 01051 { 01052 return err; 01053 } 01054 } 01055 01056 free_vector_i32(markers_v); 01057 01058 return NULL; 01059 } 01060 01061 01063 static Error* read_nb_gmm_model_node 01064 ( 01065 NB_gmm_model_node* node, 01066 const char* model_dirname 01067 ) 01068 { 01069 uint32_t i; 01070 char buf[1024] = {0}; 01071 Error* err; 01072 01073 snprintf(buf, 1024, "%s/%s", model_dirname, node->model_fname); 01074 if ((err = read_nb_gmm_model(&(node->model), buf))) 01075 { 01076 return err; 01077 } 01078 01079 for (i = 0; i < node->num_groups; i++) 01080 { 01081 if (node->subtrees[ i ]) 01082 { 01083 if ((err = read_nb_gmm_model_node(node->subtrees[ i ], 01084 model_dirname))) 01085 { 01086 return err; 01087 } 01088 } 01089 } 01090 01091 return NULL; 01092 } 01093 01094 01101 Error* read_nb_gmm_model_tree 01102 ( 01103 NB_gmm_model_tree** tree_out, 01104 const char* tree_xml_fname, 01105 const char* tree_dtd_fname, 01106 const char* model_dirname 01107 ) 01108 { 01109 xmlDoc* xml_doc; 01110 xmlNode* xml_node; 01111 xmlNode* it; 01112 Error* err; 01113 int slen = 256; 01114 char str[slen]; 01115 01116 assert(tree_xml_fname && model_dirname); 01117 01118 if ((err = read_nb_gmm_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname))) 01119 { 01120 return err; 01121 } 01122 01123 for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next) 01124 { 01125 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model")) 01126 { 01127 xml_node = it; 01128 break; 01129 } 01130 } 01131 01132 if ((err = create_nb_gmm_model_tree_from_xml_node(tree_out, NULL, 0, 01133 xml_node))) 01134 { 01135 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 01136 return JWSC_EARG(str); 01137 } 01138 01139 if ((err = read_nb_gmm_model_node(*tree_out, model_dirname))) 01140 { 01141 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 01142 return JWSC_EARG(str); 01143 } 01144 01145 return NULL; 01146 } 01147 01148 01153 Error* write_nb_gmm_model_tree 01154 ( 01155 const NB_gmm_model_tree* tree, 01156 const char* model_dirname 01157 ) 01158 { 01159 uint32_t i; 01160 char buf[1024] = {0}; 01161 Error* err; 01162 int slen = 256; 01163 char str[slen]; 01164 01165 snprintf(buf, 1024, "%s/%s", model_dirname, tree->model_fname); 01166 if ((err = write_nb_gmm_model(tree->model, buf))) 01167 { 01168 return err; 01169 } 01170 01171 for (i = 0; i < tree->num_groups; i++) 01172 { 01173 if (tree->subtrees[ i ]) 01174 { 01175 if ((err = write_nb_gmm_model_tree(tree->subtrees[ i ], 01176 model_dirname))) 01177 { 01178 snprintf(str, slen, "%s: %s", tree->model_fname, err->msg); 01179 return JWSC_EARG(str); 01180 } 01181 } 01182 } 01183 01184 return NULL; 01185 } 01186 01187 01189 void free_nb_gmm_model_tree(NB_gmm_model_tree* tree) 01190 { 01191 uint32_t i; 01192 01193 if (!tree) 01194 return; 01195 01196 for (i = 0; i < tree->num_groups; i++) 01197 { 01198 free_nb_gmm_model_tree(tree->subtrees[ i ]); 01199 free_vector_u32(tree->altlabels[ i ]); 01200 } 01201 01202 free(tree->subtrees); 01203 free(tree->altlabels); 01204 free_vector_u32(tree->labels); 01205 free_vector_d(tree->priors); 01206 free_nb_gmm_model(tree->model); 01207 free(tree->model_fname); 01208 free(tree); 01209 }