Haplo Prediction
predict haplogroups
|
00001 /* 00002 * This work is licensed under a Creative Commons 00003 * Attribution-Noncommercial-Share Alike 3.0 United States License. 00004 * 00005 * http://creativecommons.org/licenses/by-nc-sa/3.0/us/ 00006 * 00007 * You are free: 00008 * 00009 * to Share - to copy, distribute, display, and perform the work 00010 * to Remix - to make derivative works 00011 * 00012 * Under the following conditions: 00013 * 00014 * Attribution. You must attribute the work in the manner specified by the 00015 * author or licensor (but not in any way that suggests that they endorse you 00016 * or your use of the work). 00017 * 00018 * Noncommercial. You may not use this work for commercial purposes. 00019 * 00020 * Share Alike. If you alter, transform, or build upon this work, you may 00021 * distribute the resulting work only under the same or similar license to 00022 * this one. 00023 * 00024 * For any reuse or distribution, you must make clear to others the license 00025 * terms of this work. The best way to do this is by including this header. 00026 * 00027 * Any of the above conditions can be waived if you get permission from the 00028 * copyright holder. 00029 * 00030 * Apart from the remix rights granted under this license, nothing in this 00031 * license impairs or restricts the author's moral rights. 00032 */ 00033 00034 00046 #include <config.h> 00047 00048 #include <stdlib.h> 00049 #include <stdio.h> 00050 #include <inttypes.h> 00051 #include <assert.h> 00052 #include <string.h> 00053 #include <errno.h> 00054 #include <math.h> 00055 00056 #include <libxml/tree.h> 00057 #include <libxml/parser.h> 00058 #include <libxml/valid.h> 00059 00060 #ifdef HAPLO_HAVE_DMALLOC 00061 #include <dmalloc.h> 00062 #endif 00063 00064 #include <jwsc/base/error.h> 00065 #include <jwsc/base/limits.h> 00066 #include <jwsc/base/file_io.h> 00067 #include <jwsc/vector/vector.h> 00068 #include <jwsc/vector/vector_io.h> 00069 #include <jwsc/vector/vector_math.h> 00070 #include <jwsc/matrix/matrix.h> 00071 #include <jwsc/matrix/matrix_io.h> 00072 #include <jwsc/matblock/matblock.h> 00073 #include <jwsc/matblock/matblock_io.h> 00074 #include <jwsc/stat/gmm.h> 00075 00076 #include "xml.h" 00077 #include "haplo_groups.h" 00078 #include "input.h" 00079 #include "mv_gmm.h" 00080 00081 00092 void train_mv_gmm_model 00093 ( 00094 MV_gmm_model** model_out, 00095 const Vector_u32* labels, 00096 const Matrix_i32* markers, 00097 const Vector_d* label_priors, 00098 uint32_t num_components 00099 ) 00100 { 00101 uint32_t l, num_labels; 00102 uint32_t s, num_samples; 00103 uint32_t m, num_markers; 00104 uint32_t n, N; 00105 00106 Matrix_d** mu; 00107 Matblock_d** sigma; 00108 Vector_d** pi; 00109 Vector_d* priors; 00110 Matrix_d* markers_d = NULL; 00111 00112 assert(markers->num_rows == labels->num_elts); 00113 00114 num_markers = markers->num_cols; 00115 num_samples = markers->num_rows; 00116 num_labels = get_num_haplo_groups(); 00117 00118 if (*model_out != NULL) 00119 { 00120 free_mv_gmm_model(*model_out); 00121 } 00122 00123 assert(mu = malloc(num_labels*sizeof(Matrix_d*))); 00124 assert(sigma = malloc(num_labels*sizeof(Matblock_d*))); 00125 assert(pi = malloc(num_labels*sizeof(Vector_d*))); 00126 priors = NULL; 00127 00128 for (l = 0; l < num_labels; l++) 00129 { 00130 mu [ l ] = NULL; 00131 sigma[ l ] = NULL; 00132 pi [ l ] = NULL; 00133 00134 N = 0; 00135 for (s = 0; s < num_samples; s++) 00136 { 00137 if (labels->elts[ s ] == l) 00138 { 00139 N++; 00140 } 00141 } 00142 00143 // TODO we should probably handle the case of N=1 by creating a 00144 // distribution with the sample as the mean and a very small diagonal 00145 // covariance matrix. 00146 if (N < num_components*2) 00147 continue; 00148 00149 create_matrix_d(&markers_d, N, num_markers); 00150 00151 n = 0; 00152 for (s = 0; s < num_samples; s++) 00153 { 00154 if (labels->elts[ s ] == l) 00155 { 00156 for (m = 0; m < num_markers; m++) 00157 { 00158 markers_d->elts[ n ][ m ] = markers->elts[ s ][ m ]; 00159 } 00160 n++; 00161 } 00162 } 00163 00164 train_gmm_d(&(mu[l]), &(sigma[l]), &(pi[l]), NULL, markers_d, 00165 num_components); 00166 } 00167 00168 free_matrix_d(markers_d); 00169 00170 if (label_priors == NULL) 00171 { 00172 create_zero_vector_d(&priors, num_labels); 00173 00174 for (s = 0; s < num_samples; s++) 00175 { 00176 l = labels->elts[ s ]; 00177 priors->elts[ l ]++; 00178 } 00179 normalize_vector_sum_d(&priors, priors); 00180 } 00181 else 00182 { 00183 normalize_vector_sum_d(&priors, label_priors); 00184 } 00185 00186 assert(*model_out = malloc(sizeof(MV_gmm_model))); 00187 (*model_out)->num_markers = num_markers; 00188 (*model_out)->num_components = num_components; 00189 (*model_out)->mu = mu; 00190 (*model_out)->sigma = sigma; 00191 (*model_out)->pi = pi; 00192 (*model_out)->priors = priors; 00193 } 00194 00195 00207 Error* predict_label_with_mv_gmm_model 00208 ( 00209 uint32_t* label_out, 00210 double* confidence_out, 00211 const Vector_i32* markers, 00212 const MV_gmm_model* model, 00213 uint32_t order 00214 ) 00215 { 00216 uint32_t num_markers; 00217 uint32_t marker; 00218 uint32_t num_labels; 00219 uint32_t label; 00220 uint32_t best_label; 00221 uint32_t i; 00222 double log_map; 00223 double log_ll; 00224 double log_prior; 00225 double log_posterior; 00226 double posterior_sum; 00227 00228 Vector_u32* best_labels = NULL; 00229 Vector_d* best_labels_conf = NULL; 00230 Matrix_d* marker_vals = NULL; 00231 00232 const Matrix_d* mu; 00233 const Matblock_d* sigma; 00234 const Vector_d* pi; 00235 00236 num_labels = get_num_haplo_groups(); 00237 num_markers = markers->num_elts; 00238 00239 log_map = log(0); 00240 posterior_sum = 0; 00241 00242 create_init_vector_u32(&best_labels, num_labels, 0); 00243 create_init_vector_d(&best_labels_conf, num_labels, 0); 00244 00245 create_matrix_d(&marker_vals, 1, num_markers); 00246 00247 for (label = 0; label < num_labels; label++) 00248 { 00249 if (!model->mu[ label ]) 00250 continue; 00251 00252 for (marker = 0; marker < num_markers; marker++) 00253 { 00254 if (markers->elts[ marker ] <= 0) 00255 { 00256 // Impute missing value with the mean. 00257 marker_vals->elts[0][ marker ] = 0; 00258 for (i = 0; i < model->num_components; i++) 00259 { 00260 marker_vals->elts[0][ marker ] += 00261 model->mu[ label ]->elts[ i ][ marker ] * 00262 model->pi[ label ]->elts[ i ]; 00263 } 00264 } 00265 else 00266 { 00267 marker_vals->elts[0][ marker ] = markers->elts[ marker ]; 00268 } 00269 } 00270 00271 mu = model->mu[ label ]; 00272 sigma = model->sigma[ label ]; 00273 pi = model->pi[ label ]; 00274 00275 log_ll = gmm_log_likelihood_d(mu, sigma, pi, marker_vals); 00276 00277 if (model->priors->elts[ label ] < JWSC_MIN_LOG_ARG) 00278 { 00279 log_prior = log(JWSC_MIN_LOG_ARG); 00280 } 00281 else 00282 { 00283 log_prior = log(model->priors->elts[ label ]); 00284 } 00285 00286 log_posterior = log_ll + log_prior; 00287 posterior_sum += exp(log_posterior); 00288 00289 if (log_posterior > log_map) 00290 { 00291 log_map = log_posterior; 00292 best_label = label; 00293 for (i = num_labels-1; i > 0; i--) 00294 { 00295 best_labels->elts[ i ] = best_labels->elts[ i - 1]; 00296 best_labels_conf->elts[ i ] = best_labels_conf->elts[ i - 1]; 00297 } 00298 best_labels->elts[ 0 ] = best_label; 00299 best_labels_conf->elts[ 0 ] = log_map; 00300 } 00301 } 00302 00303 assert(order < best_labels->num_elts); 00304 00305 *label_out = best_labels->elts[ order ]; 00306 00307 if (posterior_sum > 0) 00308 { 00309 assert(finite(*confidence_out = exp(best_labels_conf->elts[order]) / 00310 (double)posterior_sum)); 00311 } 00312 else 00313 { 00314 *confidence_out = 0; 00315 } 00316 00317 free_vector_u32(best_labels); 00318 free_vector_d(best_labels_conf); 00319 free_matrix_d(marker_vals); 00320 00321 return NULL; 00322 } 00323 00324 00338 Error* predict_labels_with_mv_gmm_model 00339 ( 00340 Vector_u32** labels_out, 00341 Vector_d** confidence_out, 00342 const Matrix_i32* markers, 00343 const MV_gmm_model* model 00344 ) 00345 { 00346 uint32_t num_samples; 00347 uint32_t sample; 00348 uint32_t num_markers; 00349 uint32_t marker; 00350 00351 Vector_u32* labels; 00352 Vector_d* confidence; 00353 Vector_i32* markers_v; 00354 Error* e; 00355 00356 num_samples = markers->num_rows; 00357 num_markers = markers->num_cols; 00358 00359 create_vector_u32(labels_out, num_samples); 00360 labels = *labels_out; 00361 00362 create_vector_d(confidence_out, num_samples); 00363 confidence = *confidence_out; 00364 00365 markers_v = NULL; 00366 create_vector_i32(&markers_v, num_markers); 00367 00368 for (sample = 0; sample < num_samples; sample++) 00369 { 00370 for (marker = 0; marker < num_markers; marker++) 00371 { 00372 markers_v->elts[ marker ] = markers->elts[ sample ][marker]; 00373 } 00374 00375 if ((e = predict_label_with_mv_gmm_model(&(labels->elts[ sample ]), 00376 &(confidence->elts[ sample ]), markers_v, model, 0)) != NULL) 00377 { 00378 return e; 00379 } 00380 } 00381 00382 free_vector_i32(markers_v); 00383 00384 return NULL; 00385 } 00386 00387 00396 Error* read_mv_gmm_model(MV_gmm_model** model_out, const char* fname) 00397 { 00398 uint32_t num_labels; 00399 uint32_t i; 00400 FILE* fp; 00401 Vector_u32* v = NULL; 00402 MV_gmm_model* model; 00403 Error* e; 00404 int slen = 256; 00405 char str[slen]; 00406 00407 num_labels = get_num_haplo_groups(); 00408 00409 if ((fp = fopen(fname, "r")) == NULL) 00410 { 00411 snprintf(str, slen, "%s: %s", fname, strerror(errno)); 00412 return JWSC_EARG(str); 00413 } 00414 00415 if (*model_out == NULL) 00416 { 00417 assert((*model_out = malloc(sizeof(MV_gmm_model))) != NULL); 00418 model = *model_out; 00419 model->priors = NULL; 00420 } 00421 00422 if (fscanf(fp, "%u\n%u\n", &(model->num_markers), 00423 &(model->num_components)) != 2) 00424 { 00425 snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file"); 00426 return JWSC_EARG(str); 00427 } 00428 00429 assert(model->mu = malloc(num_labels*sizeof(void*))); 00430 assert(model->sigma = malloc(num_labels*sizeof(void*))); 00431 assert(model->pi = malloc(num_labels*sizeof(void*))); 00432 00433 if ((e = read_vector_fp_u32(&v, fp)) || (v->num_elts != num_labels)) 00434 { 00435 return JWSC_EARG("Improperly formatted model file"); 00436 } 00437 00438 for (i = 0; i < num_labels; i++) 00439 { 00440 model->mu[ i ] = NULL; 00441 model->sigma[ i ] = NULL; 00442 model->pi[ i ] = NULL; 00443 00444 if (v->elts[ i ]) 00445 { 00446 if ((e = read_matrix_with_header_fp_d(&(model->mu[i]), fp)) || 00447 (e = read_matblock_with_header_fp_d(&(model->sigma[i]), fp)) || 00448 (e = read_vector_with_header_fp_d(&(model->pi[i]), fp))) 00449 { 00450 return JWSC_EARG("Improperly formatted model file"); 00451 } 00452 } 00453 } 00454 00455 if ((e = read_vector_with_header_fp_d(&(model->priors), fp)) != NULL) 00456 { 00457 snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file"); 00458 return JWSC_EARG(str); 00459 } 00460 00461 if (fclose(fp) != 0) 00462 { 00463 snprintf(str, slen, "%s: %s", fname, strerror(errno)); 00464 return JWSC_EARG(str); 00465 } 00466 00467 return NULL; 00468 } 00469 00470 00477 Error* write_mv_gmm_model(MV_gmm_model* model, const char* fname) 00478 { 00479 uint32_t num_labels; 00480 uint32_t i; 00481 FILE* fp; 00482 int slen = 256; 00483 char str[slen]; 00484 00485 num_labels = get_num_haplo_groups(); 00486 00487 if ((fp = fopen(fname, "w")) == NULL) 00488 { 00489 snprintf(str, slen, "%s: %s", fname, strerror(errno)); 00490 return JWSC_EARG(str); 00491 } 00492 00493 fprintf(fp, "%u\n", model->num_markers); 00494 fprintf(fp, "%u\n", model->num_components); 00495 00496 for (i = 0; i < num_labels; i++) 00497 { 00498 fprintf(fp, "%u ", (model->mu[ i ] != NULL)); 00499 } 00500 fprintf(fp, "\n"); 00501 00502 for (i = 0; i < num_labels; i++) 00503 { 00504 if (model->mu[ i ]) 00505 { 00506 write_matrix_with_header_fp_d(model->mu[ i ], fp); 00507 write_matblock_with_header_fp_d(model->sigma[ i ], fp); 00508 write_vector_with_header_fp_d(model->pi[ i ], fp); 00509 } 00510 } 00511 00512 write_vector_with_header_fp_d(model->priors, fp); 00513 00514 if (fclose(fp) != 0) 00515 { 00516 snprintf(str, slen, "%s: %s", fname, strerror(errno)); 00517 return JWSC_EARG(str); 00518 } 00519 00520 return NULL; 00521 } 00522 00523 00525 void free_mv_gmm_model(MV_gmm_model* model) 00526 { 00527 uint32_t i; 00528 uint32_t num_labels; 00529 00530 if (!model) 00531 return; 00532 00533 num_labels = get_num_haplo_groups(); 00534 00535 for (i = 0; i < num_labels; i++) 00536 { 00537 free_matrix_d(model->mu[ i ]); 00538 free_matblock_d(model->sigma[ i ]); 00539 free_vector_d(model->pi[ i ]); 00540 } 00541 free(model->mu); 00542 free(model->sigma); 00543 free(model->pi); 00544 free_vector_d(model->priors); 00545 free(model); 00546 } 00547 00548 00550 static Error* read_mv_gmm_xml_doc 00551 ( 00552 xmlDoc** xml_doc_out, 00553 const char* xml_fname, 00554 const char* dtd_fname 00555 ) 00556 { 00557 xmlParserCtxt* xml_parse_ctxt; 00558 xmlValidCtxt* xml_valid_ctxt; 00559 xmlDtd* xml_dtd; 00560 int slen = 256; 00561 char str[slen]; 00562 00563 assert(xml_parse_ctxt = xmlNewParserCtxt()); 00564 00565 if (!(*xml_doc_out = xmlCtxtReadFile(xml_parse_ctxt, xml_fname, NULL, 0))) 00566 { 00567 snprintf(str, slen, "%s: %s", xml_fname, "Could not parse file"); 00568 return JWSC_EARG(str); 00569 } 00570 00571 xmlFreeParserCtxt(xml_parse_ctxt); 00572 00573 if (dtd_fname) 00574 { 00575 assert(xml_valid_ctxt = xmlNewValidCtxt()); 00576 00577 if (!(xml_dtd = xmlParseDTD(NULL, (xmlChar*)dtd_fname))) 00578 { 00579 snprintf(str, slen, "%s: %s", dtd_fname, "Could not parse DTD"); 00580 return JWSC_EARG(str); 00581 } 00582 00583 if (!xmlValidateDtd(xml_valid_ctxt, *xml_doc_out, xml_dtd)) 00584 { 00585 snprintf(str, slen, "%s: %s", xml_fname, "XML file not valid"); 00586 return JWSC_EARG(str); 00587 } 00588 00589 xmlFreeValidCtxt(xml_valid_ctxt); 00590 xmlFreeDtd(xml_dtd); 00591 } 00592 00593 return NULL; 00594 } 00595 00596 00603 static Error* create_mv_gmm_model_tree_from_xml_node 00604 ( 00605 MV_gmm_model_tree** tree_out, 00606 const MV_gmm_model_tree* parent, 00607 uint32_t parent_label, 00608 xmlNode* xml_node 00609 ) 00610 { 00611 MV_gmm_model_tree* tree; 00612 uint32_t i, j; 00613 uint32_t label; 00614 uint32_t altlabel; 00615 uint32_t num_altlabels; 00616 double p; 00617 xmlAttr* attr; 00618 xmlNode* it; 00619 xmlNode* itt; 00620 Error* err; 00621 const char* fname; 00622 00623 if (*tree_out) 00624 { 00625 free_mv_gmm_model_tree(*tree_out); 00626 } 00627 00628 assert(*tree_out = malloc(sizeof(MV_gmm_model_tree))); 00629 tree = *tree_out; 00630 00631 tree->parent = parent; 00632 tree->parent_label = parent_label; 00633 tree->num_groups = 0; 00634 tree->subtrees = NULL; 00635 tree->labels = NULL; 00636 tree->altlabels = NULL; 00637 tree->priors = NULL; 00638 tree->model = NULL; 00639 tree->model_fname = NULL; 00640 00641 assert(XMLStrEqual(xml_node->name, "model")); 00642 00643 for (it = xml_node->children; it; it = it->next) 00644 { 00645 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "group")) 00646 { 00647 tree->num_groups++; 00648 } 00649 } 00650 00651 assert(tree->subtrees = calloc(tree->num_groups, sizeof(void*))); 00652 assert(tree->altlabels = calloc(tree->num_groups, sizeof(void*))); 00653 create_vector_u32(&(tree->labels), tree->num_groups); 00654 00655 for (attr = xml_node->properties; attr; attr = attr->next) 00656 { 00657 if (XMLStrEqual(attr->name, "priors") && 00658 XMLStrEqual(attr->children->content, "true")) 00659 { 00660 create_init_vector_d(&(tree->priors), get_num_haplo_groups(), 1.0); 00661 break; 00662 } 00663 } 00664 00665 i = 0; 00666 for (it = xml_node->children; it; it = it->next) 00667 { 00668 j = 0; 00669 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "file")) 00670 { 00671 fname = (const char*) it->children->content; 00672 tree->model_fname = malloc((strlen(fname)+1) * sizeof(char)); 00673 strcpy(tree->model_fname, fname); 00674 } 00675 else if (it->type == XML_ELEMENT_NODE && 00676 XMLStrEqual(it->name, "mixture-components")) 00677 { 00678 if (sscanf((char*)it->children->content, "%d", 00679 &(tree->num_components)) != 1 || 00680 tree->num_components < 1) 00681 { 00682 return JWSC_EARG("Invalid number of mixture components"); 00683 } 00684 } 00685 else if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "group")) 00686 { 00687 num_altlabels = 0; 00688 for (itt = it->children; itt; itt = itt->next) 00689 { 00690 if (itt->type == XML_ELEMENT_NODE && 00691 XMLStrEqual(itt->name, "altlabel")) 00692 { 00693 num_altlabels++; 00694 } 00695 } 00696 if (num_altlabels) 00697 { 00698 create_vector_u32(&(tree->altlabels[ i ]), num_altlabels); 00699 } 00700 00701 for (itt = it->children; itt; itt = itt->next) 00702 { 00703 if (itt->type == XML_ELEMENT_NODE && 00704 XMLStrEqual(itt->name, "label")) 00705 { 00706 if ((err = lookup_haplo_group_index_from_label(&label, 00707 (const char*) itt->children->content))) 00708 { 00709 return err; 00710 } 00711 tree->labels->elts[ i ] = label; 00712 } 00713 else if (itt->type == XML_ELEMENT_NODE && 00714 XMLStrEqual(itt->name, "altlabel")) 00715 { 00716 if ((err = lookup_haplo_group_index_from_label(&altlabel, 00717 (const char*) itt->children->content))) 00718 { 00719 return err; 00720 } 00721 tree->altlabels[ i ]->elts[ j++ ] = altlabel; 00722 } 00723 else if (tree->priors && itt->type == XML_ELEMENT_NODE && 00724 XMLStrEqual(itt->name, "prior")) 00725 { 00726 if (sscanf((char*)itt->children->content, "%lf", &p) != 1 || 00727 p < 0 || p > 1) 00728 { 00729 return JWSC_EARG("Invalid prior"); 00730 } 00731 tree->priors->elts[ i ] = p; 00732 } 00733 else if (itt->type == XML_ELEMENT_NODE && 00734 XMLStrEqual(itt->name, "model")) 00735 { 00736 if ((err = create_mv_gmm_model_tree_from_xml_node( 00737 &(tree->subtrees[ i ]), tree, label, itt))) 00738 { 00739 return err; 00740 } 00741 } 00742 } 00743 i++; 00744 } 00745 } 00746 assert(i == tree->num_groups); 00747 00748 return NULL; 00749 } 00750 00751 00756 static Error* create_model_training_data 00757 ( 00758 Vector_u32** train_labels_out, 00759 Matrix_i32** train_markers_out, 00760 const Vector_u32* data_labels, 00761 const Matrix_i32* data_markers, 00762 const Vector_u32* model_labels, 00763 Vector_u32*const* model_altlabels 00764 ) 00765 { 00766 uint8_t b; 00767 uint32_t i, j, k; 00768 uint32_t n; 00769 Vector_u32* train_labels = NULL; 00770 Matrix_i32* train_markers = NULL; 00771 00772 copy_vector_u32(&train_labels, data_labels); 00773 copy_matrix_i32(&train_markers, data_markers); 00774 00775 n = 0; 00776 for (i = 0; i < data_labels->num_elts; i++) 00777 { 00778 for (j = 0; j < model_labels->num_elts; j++) 00779 { 00780 if (is_ancestor(data_labels->elts[ i ], model_labels->elts[ j ])) 00781 { 00782 train_labels->elts[ n ] = model_labels->elts[ j ]; 00783 copy_matrix_block_into_matrix_i32(train_markers, n, 0, 00784 data_markers, i, 0, 1, data_markers->num_cols); 00785 n++; 00786 break; 00787 } 00788 else if (model_altlabels[ j ]) 00789 { 00790 for (k = 0; k < model_altlabels[ j ]->num_elts; k++) 00791 { 00792 if ((b = is_ancestor(data_labels->elts[ i ], 00793 model_altlabels[ j ]->elts[ k ]))) 00794 { 00795 train_labels->elts[ n ] = model_labels->elts[ j ]; 00796 copy_matrix_block_into_matrix_i32(train_markers, n, 0, 00797 data_markers, i, 0, 1, data_markers->num_cols); 00798 n++; 00799 break; 00800 } 00801 } 00802 if (b) 00803 { 00804 break; 00805 } 00806 } 00807 } 00808 } 00809 00810 if (!n) 00811 { 00812 return JWSC_EARG("No data for model"); 00813 } 00814 00815 copy_vector_section_u32(train_labels_out, train_labels, 0, n); 00816 copy_matrix_block_i32(train_markers_out, train_markers, 0, 0, n, 00817 train_markers->num_cols); 00818 00819 free_vector_u32(train_labels); 00820 free_matrix_i32(train_markers); 00821 00822 return NULL; 00823 } 00824 00825 00827 static Error* train_mv_gmm_model_node 00828 ( 00829 MV_gmm_model_node* node, 00830 const Vector_u32* labels, 00831 const Matrix_i32* markers 00832 ) 00833 { 00834 uint32_t i; 00835 Vector_u32* node_labels = NULL; 00836 Matrix_i32* node_markers = NULL; 00837 Error* err; 00838 int slen = 256; 00839 char str[slen]; 00840 00841 if ((err = create_model_training_data(&node_labels, &node_markers, labels, 00842 markers, node->labels, node->altlabels))) 00843 { 00844 snprintf(str, slen, "%s: %s", node->model_fname, err->msg); 00845 return JWSC_EARG(str); 00846 } 00847 00848 train_mv_gmm_model(&(node->model), node_labels, node_markers, 00849 node->priors, node->num_components); 00850 00851 for (i = 0; i < node->num_groups; i++) 00852 { 00853 if (node->subtrees[ i ]) 00854 { 00855 if ((err = train_mv_gmm_model_node(node->subtrees[ i ], labels, 00856 markers))) 00857 { 00858 return err; 00859 } 00860 } 00861 } 00862 00863 free_vector_u32(node_labels); 00864 free_matrix_i32(node_markers); 00865 00866 return NULL; 00867 } 00868 00869 00877 Error* train_mv_gmm_model_tree 00878 ( 00879 MV_gmm_model_tree** tree_out, 00880 const Vector_u32* labels, 00881 const Matrix_i32* markers, 00882 const char* tree_xml_fname, 00883 const char* tree_dtd_fname 00884 ) 00885 { 00886 xmlDoc* xml_doc; 00887 xmlNode* xml_node; 00888 xmlNode* it; 00889 Error* err; 00890 int slen = 256; 00891 char str[slen]; 00892 00893 assert(tree_xml_fname); 00894 00895 if ((err = read_mv_gmm_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname))) 00896 { 00897 return err; 00898 } 00899 00900 for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next) 00901 { 00902 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model")) 00903 { 00904 xml_node = it; 00905 break; 00906 } 00907 } 00908 00909 if ((err = create_mv_gmm_model_tree_from_xml_node(tree_out, NULL, 0, 00910 xml_node))) 00911 { 00912 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 00913 return JWSC_EARG(str); 00914 } 00915 00916 if ((err = train_mv_gmm_model_node(*tree_out, labels, markers))) 00917 { 00918 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 00919 return JWSC_EARG(str); 00920 } 00921 00922 return NULL; 00923 } 00924 00925 00927 static Error* recursively_predict_label_in_model_tree 00928 ( 00929 uint32_t* label_out, 00930 double* confidence_out, 00931 const MV_gmm_model_tree* tree, 00932 const Vector_i32* markers_v, 00933 uint32_t order 00934 ) 00935 { 00936 uint32_t subtree_label; 00937 uint32_t i; 00938 double confidence; 00939 Error* err; 00940 00941 if ((err = predict_label_with_mv_gmm_model(label_out, &confidence, 00942 markers_v, tree->model, order)) != NULL) 00943 { 00944 return err; 00945 } 00946 *confidence_out *= confidence; 00947 00948 for (i = 0; i < tree->num_groups; i++) 00949 { 00950 if (tree->subtrees[ i ]) 00951 { 00952 subtree_label = tree->subtrees[ i ]->parent_label; 00953 00954 if (subtree_label == *label_out) 00955 { 00956 if ((err = recursively_predict_label_in_model_tree(label_out, 00957 confidence_out, tree->subtrees[ i ], markers_v, 00958 order)) != NULL) 00959 { 00960 return err; 00961 } 00962 } 00963 } 00964 } 00965 00966 return NULL; 00967 } 00968 00969 00985 Error* predict_labels_with_mv_gmm_model_tree 00986 ( 00987 Vector_u32** labels_out, 00988 Vector_d** confidence_out, 00989 const Matrix_i32* markers, 00990 const MV_gmm_model_tree* tree, 00991 uint32_t order 00992 ) 00993 { 00994 uint32_t s, num_samples; 00995 uint32_t m, num_markers; 00996 00997 Vector_u32* labels; 00998 Vector_d* confidence; 00999 Vector_i32* markers_v; 01000 Error* err; 01001 01002 num_samples = markers->num_rows; 01003 num_markers = markers->num_cols; 01004 01005 create_vector_u32(labels_out, num_samples); 01006 labels = *labels_out; 01007 01008 create_init_vector_d(confidence_out, num_samples, 1.0); 01009 confidence = *confidence_out; 01010 01011 markers_v = NULL; 01012 create_vector_i32(&markers_v, num_markers); 01013 01014 for (s = 0; s < num_samples; s++) 01015 { 01016 for (m = 0; m < num_markers; m++) 01017 { 01018 markers_v->elts[ m ]= markers->elts[ s ][ m ]; 01019 } 01020 01021 if ((err = recursively_predict_label_in_model_tree( 01022 &(labels->elts[ s ]), 01023 &(confidence->elts[ s ]), tree, markers_v, order)) 01024 != NULL) 01025 { 01026 return err; 01027 } 01028 } 01029 01030 free_vector_i32(markers_v); 01031 01032 return NULL; 01033 } 01034 01035 01037 static Error* read_mv_gmm_model_node 01038 ( 01039 MV_gmm_model_node* node, 01040 const char* model_dirname 01041 ) 01042 { 01043 uint32_t i; 01044 char buf[1024] = {0}; 01045 Error* err; 01046 01047 snprintf(buf, 1024, "%s/%s", model_dirname, node->model_fname); 01048 if ((err = read_mv_gmm_model(&(node->model), buf))) 01049 { 01050 return err; 01051 } 01052 01053 for (i = 0; i < node->num_groups; i++) 01054 { 01055 if (node->subtrees[ i ]) 01056 { 01057 if ((err = read_mv_gmm_model_node(node->subtrees[ i ], 01058 model_dirname))) 01059 { 01060 return err; 01061 } 01062 } 01063 } 01064 01065 return NULL; 01066 } 01067 01068 01075 Error* read_mv_gmm_model_tree 01076 ( 01077 MV_gmm_model_tree** tree_out, 01078 const char* tree_xml_fname, 01079 const char* tree_dtd_fname, 01080 const char* model_dirname 01081 ) 01082 { 01083 xmlDoc* xml_doc; 01084 xmlNode* xml_node; 01085 xmlNode* it; 01086 Error* err; 01087 int slen = 256; 01088 char str[slen]; 01089 01090 assert(tree_xml_fname && model_dirname); 01091 01092 if ((err = read_mv_gmm_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname))) 01093 { 01094 return err; 01095 } 01096 01097 for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next) 01098 { 01099 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model")) 01100 { 01101 xml_node = it; 01102 break; 01103 } 01104 } 01105 01106 if ((err = create_mv_gmm_model_tree_from_xml_node(tree_out, NULL, 0, 01107 xml_node))) 01108 { 01109 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 01110 return JWSC_EARG(str); 01111 } 01112 01113 if ((err = read_mv_gmm_model_node(*tree_out, model_dirname))) 01114 { 01115 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 01116 return JWSC_EARG(str); 01117 } 01118 01119 return NULL; 01120 } 01121 01122 01127 Error* write_mv_gmm_model_tree 01128 ( 01129 const MV_gmm_model_tree* tree, 01130 const char* model_dirname 01131 ) 01132 { 01133 uint32_t i; 01134 char buf[1024] = {0}; 01135 Error* err; 01136 int slen = 256; 01137 char str[slen]; 01138 01139 snprintf(buf, 1024, "%s/%s", model_dirname, tree->model_fname); 01140 if ((err = write_mv_gmm_model(tree->model, buf))) 01141 { 01142 return err; 01143 } 01144 01145 for (i = 0; i < tree->num_groups; i++) 01146 { 01147 if (tree->subtrees[ i ]) 01148 { 01149 if ((err = write_mv_gmm_model_tree(tree->subtrees[ i ], 01150 model_dirname))) 01151 { 01152 snprintf(str, slen, "%s: %s", tree->model_fname, err->msg); 01153 return JWSC_EARG(str); 01154 } 01155 } 01156 } 01157 01158 return NULL; 01159 } 01160 01161 01163 void free_mv_gmm_model_tree(MV_gmm_model_tree* tree) 01164 { 01165 uint32_t i; 01166 01167 if (!tree) 01168 return; 01169 01170 for (i = 0; i < tree->num_groups; i++) 01171 { 01172 free_mv_gmm_model_tree(tree->subtrees[ i ]); 01173 free_vector_u32(tree->altlabels[ i ]); 01174 } 01175 01176 free(tree->subtrees); 01177 free(tree->altlabels); 01178 free_vector_u32(tree->labels); 01179 free_vector_d(tree->priors); 01180 free_mv_gmm_model(tree->model); 01181 free(tree->model_fname); 01182 free(tree); 01183 }