/* Haplo Prediction: predict haplogroups */
00001 /* 00002 * This work is licensed under a Creative Commons 00003 * Attribution-Noncommercial-Share Alike 3.0 United States License. 00004 * 00005 * http://creativecommons.org/licenses/by-nc-sa/3.0/us/ 00006 * 00007 * You are free: 00008 * 00009 * to Share - to copy, distribute, display, and perform the work 00010 * to Remix - to make derivative works 00011 * 00012 * Under the following conditions: 00013 * 00014 * Attribution. You must attribute the work in the manner specified by the 00015 * author or licensor (but not in any way that suggests that they endorse you 00016 * or your use of the work). 00017 * 00018 * Noncommercial. You may not use this work for commercial purposes. 00019 * 00020 * Share Alike. If you alter, transform, or build upon this work, you may 00021 * distribute the resulting work only under the same or similar license to 00022 * this one. 00023 * 00024 * For any reuse or distribution, you must make clear to others the license 00025 * terms of this work. The best way to do this is by including this header. 00026 * 00027 * Any of the above conditions can be waived if you get permission from the 00028 * copyright holder. 00029 * 00030 * Apart from the remix rights granted under this license, nothing in this 00031 * license impairs or restricts the author's moral rights. 
00032 */ 00033 00034 00046 #include <config.h> 00047 00048 #include <stdlib.h> 00049 #include <stdio.h> 00050 #include <inttypes.h> 00051 #include <assert.h> 00052 #include <string.h> 00053 #include <errno.h> 00054 #include <math.h> 00055 00056 #include <libxml/tree.h> 00057 #include <libxml/parser.h> 00058 #include <libxml/valid.h> 00059 00060 #ifdef HAPLO_HAVE_DMALLOC 00061 #include <dmalloc.h> 00062 #endif 00063 00064 #include <jwsc/base/error.h> 00065 #include <jwsc/base/limits.h> 00066 #include <jwsc/base/file_io.h> 00067 #include <jwsc/vector/vector.h> 00068 #include <jwsc/vector/vector_io.h> 00069 #include <jwsc/vector/vector_math.h> 00070 #include <jwsc/matrix/matrix.h> 00071 #include <jwsc/matrix/matrix_io.h> 00072 00073 #include "xml.h" 00074 #include "haplo_groups.h" 00075 #include "nb_freq.h" 00076 00077 00089 void train_nb_freq_model 00090 ( 00091 NB_freq_model** model_out, 00092 const Vector_u32* labels, 00093 const Matrix_i32* markers, 00094 const Vector_d* label_priors 00095 ) 00096 { 00097 uint32_t num_labels; 00098 uint32_t num_samples; 00099 uint32_t sample; 00100 uint32_t marker; 00101 uint32_t label; 00102 uint32_t i; 00103 uint32_t num_markers; 00104 uint32_t num_marker_vals; 00105 uint32_t min_marker_val; 00106 uint32_t max_marker_val; 00107 int32_t marker_val; 00108 double sum; 00109 00110 Matrix_d* marker_given_label; 00111 Vector_d* priors; 00112 00113 assert(markers->num_rows == labels->num_elts); 00114 00115 num_labels = get_num_haplo_groups(); 00116 00117 if (*model_out != NULL) 00118 { 00119 marker_given_label = (*model_out)->marker_given_label; 00120 priors = (*model_out)->priors; 00121 } 00122 else 00123 { 00124 marker_given_label = NULL; 00125 priors = NULL; 00126 } 00127 00128 num_markers = markers->num_cols; 00129 num_samples = markers->num_rows; 00130 00131 max_marker_val = 1; 00132 min_marker_val = 1; 00133 00134 for (sample = 0; sample < num_samples; sample++) 00135 { 00136 for (marker = 0; marker < num_markers; marker++) 00137 
{ 00138 if (markers->elts[ sample ][ marker ] > (int32_t)min_marker_val) 00139 { 00140 if (markers->elts[ sample ][ marker ] > (int32_t)max_marker_val) 00141 { 00142 max_marker_val = markers->elts[ sample ][ marker ]; 00143 } 00144 } 00145 } 00146 } 00147 00148 num_marker_vals = max_marker_val - min_marker_val + 1; 00149 00150 create_zero_matrix_d(&marker_given_label, num_labels, 00151 num_marker_vals*num_markers); 00152 00153 for (sample = 0; sample < num_samples; sample++) 00154 { 00155 label = labels->elts[ sample ]; 00156 00157 for (marker = 0; marker < num_markers; marker++) 00158 { 00159 marker_val = markers->elts[ sample ][ marker ]; 00160 00161 /* Ignore non-positive marker values. */ 00162 if (marker_val > min_marker_val) 00163 { 00164 i = num_marker_vals*marker + 00165 ((uint32_t)marker_val - min_marker_val); 00166 00167 marker_given_label->elts[ label ][ i ]++; 00168 } 00169 } 00170 } 00171 00172 /* If a label doesn't have any representation in the data, leave the 00173 * entire row of marker_given_label zero for sparse reading and writing. 00174 * Otherwise, if a label does appear, set all it's "zero" counts to a 00175 * small value so that there is no zero probability, just a small value. 
*/ 00176 for (label = 0; label < num_labels; label++) 00177 { 00178 sum = 0; 00179 00180 for (i = 0; i < marker_given_label->num_cols; i++) 00181 { 00182 sum += marker_given_label->elts[ label ][ i ]; 00183 } 00184 00185 if (sum) 00186 { 00187 for (i = 0; i < marker_given_label->num_cols; i++) 00188 { 00189 if (marker_given_label->elts[ label ][ i ] == 0) 00190 { 00191 marker_given_label->elts[ label ][ i ] = 0.001; 00192 } 00193 } 00194 } 00195 } 00196 00197 for (label = 0; label < num_labels; label++) 00198 { 00199 for (marker = 0; marker < num_markers; marker++) 00200 { 00201 sum = 0; 00202 00203 for (marker_val = min_marker_val; marker_val < num_marker_vals; 00204 marker_val++) 00205 { 00206 i = num_marker_vals*marker + 00207 ((uint32_t)marker_val - min_marker_val); 00208 00209 sum += marker_given_label->elts[ label ][ i ]; 00210 } 00211 00212 if (sum == 0) 00213 break; 00214 00215 sum = 1.0 / sum; 00216 00217 for (marker_val = min_marker_val; marker_val < num_marker_vals; 00218 marker_val++) 00219 { 00220 i = num_marker_vals*marker + 00221 ((uint32_t)marker_val - min_marker_val); 00222 00223 marker_given_label->elts[ label ][ i ] *= sum; 00224 } 00225 } 00226 } 00227 00228 if (label_priors == NULL) 00229 { 00230 create_zero_vector_d(&priors, num_labels); 00231 00232 for (sample = 0; sample < num_samples; sample++) 00233 { 00234 label = labels->elts[ sample ]; 00235 priors->elts[ label ]++; 00236 } 00237 normalize_vector_sum_d(&priors, priors); 00238 } 00239 else 00240 { 00241 normalize_vector_sum_d(&priors, label_priors); 00242 } 00243 00244 00245 if (*model_out == NULL) 00246 { 00247 assert(*model_out = malloc(sizeof(NB_freq_model))); 00248 } 00249 (*model_out)->num_markers = num_markers; 00250 (*model_out)->min_marker_val = min_marker_val; 00251 (*model_out)->num_marker_vals = num_marker_vals; 00252 (*model_out)->marker_given_label = marker_given_label; 00253 (*model_out)->priors = priors; 00254 } 00255 00256 00268 Error* predict_label_with_nb_freq_model 
00269 ( 00270 uint32_t* label_out, 00271 double* confidence_out, 00272 const Vector_i32* markers, 00273 const NB_freq_model* model, 00274 uint32_t order 00275 ) 00276 { 00277 uint32_t num_markers; 00278 uint32_t marker; 00279 uint32_t num_labels; 00280 uint32_t label; 00281 uint32_t best_label; 00282 uint32_t i; 00283 int32_t marker_val; 00284 00285 double map; 00286 double likelihood; 00287 double prior; 00288 double posterior; 00289 double posterior_sum; 00290 00291 Vector_u32* best_labels = NULL; 00292 Vector_d* best_labels_conf = NULL; 00293 00294 num_labels = get_num_haplo_groups(); 00295 num_markers = markers->num_elts; 00296 00297 map = 0; 00298 posterior_sum = 0; 00299 00300 create_init_vector_u32(&best_labels, num_labels, 0); 00301 create_init_vector_d(&best_labels_conf, num_labels, 0); 00302 00303 for (label = 0; label < num_labels; label++) 00304 { 00305 likelihood = 1.0; 00306 00307 for (marker = 0; marker < num_markers; marker++) 00308 { 00309 marker_val = markers->elts[ marker ]; 00310 00311 /* Ignore non-positive or invalid marker values. 
*/ 00312 if (marker_val > 0 && marker_val <= 00313 (model->num_marker_vals + model->min_marker_val)) 00314 { 00315 i = model->num_marker_vals*marker + 00316 ((uint32_t)marker_val - model->min_marker_val); 00317 00318 likelihood *= model->marker_given_label->elts[ label ][ i ]; 00319 } 00320 } 00321 00322 prior = model->priors->elts[ label ]; 00323 00324 assert(finite(posterior = likelihood * prior)); 00325 posterior_sum += posterior; 00326 00327 if (posterior > map) 00328 { 00329 map = posterior; 00330 best_label = label; 00331 for (i = num_labels-1; i > 0; i--) 00332 { 00333 best_labels->elts[ i ] = best_labels->elts[ i - 1]; 00334 best_labels_conf->elts[ i ] = best_labels_conf->elts[ i - 1]; 00335 } 00336 best_labels->elts[ 0 ] = best_label; 00337 best_labels_conf->elts[ 0 ] = map; 00338 } 00339 } 00340 00341 assert(order < best_labels->num_elts); 00342 00343 *label_out = best_labels->elts[ order ]; 00344 00345 if (posterior_sum > 0) 00346 { 00347 assert(finite(*confidence_out = best_labels_conf->elts[order] / 00348 (double)posterior_sum)); 00349 } 00350 else 00351 { 00352 *confidence_out = 0; 00353 } 00354 00355 free_vector_u32(best_labels); 00356 free_vector_d(best_labels_conf); 00357 00358 return NULL; 00359 } 00360 00361 00375 Error* predict_labels_with_nb_freq_model 00376 ( 00377 Vector_u32** labels_out, 00378 Vector_d** confidence_out, 00379 const Matrix_i32* markers, 00380 const NB_freq_model* model 00381 ) 00382 { 00383 uint32_t num_samples; 00384 uint32_t sample; 00385 uint32_t num_markers; 00386 uint32_t marker; 00387 00388 Vector_u32* labels; 00389 Vector_d* confidence; 00390 Vector_i32* markers_v; 00391 Error* e; 00392 00393 num_samples = markers->num_rows; 00394 num_markers = markers->num_cols; 00395 00396 create_vector_u32(labels_out, num_samples); 00397 labels = *labels_out; 00398 00399 create_vector_d(confidence_out, num_samples); 00400 confidence = *confidence_out; 00401 00402 markers_v = NULL; 00403 create_vector_i32(&markers_v, num_markers); 
00404 00405 for (sample = 0; sample < num_samples; sample++) 00406 { 00407 for (marker = 0; marker < num_markers; marker++) 00408 { 00409 markers_v->elts[ marker ] = markers->elts[ sample ][marker]; 00410 } 00411 00412 if ((e = predict_label_with_nb_freq_model(&(labels->elts[ sample ]), 00413 &(confidence->elts[ sample ]), markers_v, model, 0)) != NULL) 00414 { 00415 return e; 00416 } 00417 } 00418 00419 free_vector_i32(markers_v); 00420 00421 return NULL; 00422 } 00423 00424 00433 Error* read_nb_freq_model(NB_freq_model** model_out, const char* fname) 00434 { 00435 FILE* fp; 00436 uint32_t label; 00437 uint32_t num_labels; 00438 uint32_t num_cols; 00439 uint32_t s; 00440 00441 Vector_u32* sparse = NULL; 00442 Matrix_d* sparse_marker_given_label = NULL; 00443 NB_freq_model* model; 00444 Error* e; 00445 int slen = 256; 00446 char str[slen]; 00447 00448 if ((fp = fopen(fname, "r")) == NULL) 00449 { 00450 snprintf(str, slen, "%s: %s", fname, strerror(errno)); 00451 return JWSC_EARG(str); 00452 } 00453 00454 if (*model_out == NULL) 00455 { 00456 assert((*model_out = malloc(sizeof(NB_freq_model))) != NULL); 00457 model = *model_out; 00458 model->marker_given_label = NULL; 00459 model->priors = NULL; 00460 } 00461 00462 if (fscanf(fp, "%u %u %u\n", &(model->num_markers), 00463 &(model->min_marker_val), &(model->num_marker_vals)) != 3) 00464 { 00465 snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file"); 00466 return JWSC_EARG(str); 00467 } 00468 00469 if ((e = read_vector_with_header_fp_u32(&(sparse), fp)) != NULL) 00470 { 00471 snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file"); 00472 return JWSC_EARG(str); 00473 } 00474 00475 if ((e = read_matrix_with_header_fp_d(&(sparse_marker_given_label), fp)) 00476 != NULL) 00477 { 00478 snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file"); 00479 return JWSC_EARG(str); 00480 } 00481 00482 if ((e = read_vector_with_header_fp_d(&(model->priors), fp)) != NULL) 00483 { 00484 
snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file"); 00485 return JWSC_EARG(str); 00486 } 00487 00488 if (fclose(fp) != 0) 00489 { 00490 snprintf(str, slen, "%s: %s", fname, strerror(errno)); 00491 return JWSC_EARG(str); 00492 } 00493 00494 num_labels = sparse->num_elts; 00495 num_cols = sparse_marker_given_label->num_cols; 00496 00497 create_zero_matrix_d(&(model->marker_given_label), num_labels, num_cols); 00498 00499 s = 0; 00500 for (label = 0; label < num_labels; label++) 00501 { 00502 if (sparse->elts[ label ]) 00503 { 00504 copy_matrix_block_into_matrix_d(model->marker_given_label, 00505 label, 0, sparse_marker_given_label, s++, 0, 1, num_cols); 00506 } 00507 } 00508 00509 free_vector_u32(sparse); 00510 free_matrix_d(sparse_marker_given_label); 00511 00512 return NULL; 00513 } 00514 00515 00522 Error* write_nb_freq_model(NB_freq_model* model, const char* fname) 00523 { 00524 FILE* fp; 00525 uint32_t label; 00526 uint32_t num_labels; 00527 uint32_t i; 00528 uint32_t num_cols; 00529 uint32_t s; 00530 int slen = 256; 00531 char str[slen]; 00532 00533 Vector_u32* sparse = NULL; 00534 Matrix_d* sparse_marker_given_label = NULL; 00535 00536 num_labels = model->marker_given_label->num_rows; 00537 num_cols = model->marker_given_label->num_cols; 00538 00539 create_zero_vector_u32(&sparse, num_labels); 00540 00541 s = 0; 00542 for (label = 0; label < num_labels; label++) 00543 { 00544 for (i = 0; i < num_cols; i++) 00545 { 00546 if (model->marker_given_label->elts[ label ][ i ] > 0) 00547 { 00548 sparse->elts[ label ] = 1; 00549 s++; 00550 break; 00551 } 00552 } 00553 } 00554 00555 create_matrix_d(&sparse_marker_given_label, s, num_cols); 00556 00557 s = 0; 00558 for (label = 0; label < num_labels; label++) 00559 { 00560 if (sparse->elts[ label ]) 00561 { 00562 copy_matrix_block_into_matrix_d(sparse_marker_given_label, 00563 s++, 0, model->marker_given_label, label, 0, 1, num_cols); 00564 } 00565 } 00566 00567 if ((fp = fopen(fname, "w")) == NULL) 
00568 { 00569 snprintf(str, slen, "%s: %s", fname, strerror(errno)); 00570 return JWSC_EARG(str); 00571 } 00572 00573 fprintf(fp, "%u %u %u\n", model->num_markers, model->min_marker_val, 00574 model->num_marker_vals); 00575 00576 write_vector_with_header_fp_u32(sparse, fp); 00577 write_matrix_with_header_fp_d(sparse_marker_given_label, fp); 00578 write_vector_with_header_fp_d(model->priors, fp); 00579 00580 free_vector_u32(sparse); 00581 free_matrix_d(sparse_marker_given_label); 00582 00583 if (fclose(fp) != 0) 00584 { 00585 snprintf(str, slen, "%s: %s", fname, strerror(errno)); 00586 return JWSC_EARG(str); 00587 } 00588 00589 return NULL; 00590 } 00591 00592 00594 void free_nb_freq_model(NB_freq_model* model) 00595 { 00596 if (!model) 00597 return; 00598 00599 free_matrix_d(model->marker_given_label); 00600 free_vector_d(model->priors); 00601 free(model); 00602 } 00603 00604 00606 static Error* read_nb_freq_xml_doc 00607 ( 00608 xmlDoc** xml_doc_out, 00609 const char* xml_fname, 00610 const char* dtd_fname 00611 ) 00612 { 00613 xmlParserCtxt* xml_parse_ctxt; 00614 xmlValidCtxt* xml_valid_ctxt; 00615 xmlDtd* xml_dtd; 00616 int slen = 256; 00617 char str[slen]; 00618 00619 assert(xml_parse_ctxt = xmlNewParserCtxt()); 00620 00621 if (!(*xml_doc_out = xmlCtxtReadFile(xml_parse_ctxt, xml_fname, NULL, 0))) 00622 { 00623 snprintf(str, slen, "%s: %s", xml_fname, "Could not parse file"); 00624 return JWSC_EARG(str); 00625 } 00626 00627 xmlFreeParserCtxt(xml_parse_ctxt); 00628 00629 if (dtd_fname) 00630 { 00631 assert(xml_valid_ctxt = xmlNewValidCtxt()); 00632 00633 if (!(xml_dtd = xmlParseDTD(NULL, (xmlChar*)dtd_fname))) 00634 { 00635 snprintf(str, slen, "%s: %s", dtd_fname, "Could not parse DTD"); 00636 return JWSC_EARG(str); 00637 } 00638 00639 if (!xmlValidateDtd(xml_valid_ctxt, *xml_doc_out, xml_dtd)) 00640 { 00641 snprintf(str, slen, "%s: %s", xml_fname, "XML file not valid"); 00642 return JWSC_EARG(str); 00643 } 00644 00645 xmlFreeValidCtxt(xml_valid_ctxt); 00646 
xmlFreeDtd(xml_dtd); 00647 } 00648 00649 return NULL; 00650 } 00651 00652 00659 static Error* create_nb_freq_model_tree_from_xml_node 00660 ( 00661 NB_freq_model_tree** tree_out, 00662 const NB_freq_model_tree* parent, 00663 uint32_t parent_label, 00664 xmlNode* xml_node 00665 ) 00666 { 00667 NB_freq_model_tree* tree; 00668 uint32_t i, j; 00669 uint32_t label; 00670 uint32_t altlabel; 00671 uint32_t num_altlabels; 00672 double p; 00673 xmlAttr* attr; 00674 xmlNode* it; 00675 xmlNode* itt; 00676 Error* err; 00677 const char* fname; 00678 00679 if (*tree_out) 00680 { 00681 free_nb_freq_model_tree(*tree_out); 00682 } 00683 00684 assert(*tree_out = malloc(sizeof(NB_freq_model_tree))); 00685 tree = *tree_out; 00686 00687 tree->parent = parent; 00688 tree->parent_label = parent_label; 00689 tree->num_groups = 0; 00690 tree->subtrees = NULL; 00691 tree->labels = NULL; 00692 tree->altlabels = NULL; 00693 tree->priors = NULL; 00694 tree->model = NULL; 00695 tree->model_fname = NULL; 00696 00697 assert(XMLStrEqual(xml_node->name, "model")); 00698 00699 for (it = xml_node->children; it; it = it->next) 00700 { 00701 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "group")) 00702 { 00703 tree->num_groups++; 00704 } 00705 } 00706 00707 assert(tree->subtrees = calloc(tree->num_groups, sizeof(void*))); 00708 assert(tree->altlabels = calloc(tree->num_groups, sizeof(void*))); 00709 create_vector_u32(&(tree->labels), tree->num_groups); 00710 00711 for (attr = xml_node->properties; attr; attr = attr->next) 00712 { 00713 if (XMLStrEqual(attr->name, "priors") && 00714 XMLStrEqual(attr->children->content, "true")) 00715 { 00716 create_init_vector_d(&(tree->priors), get_num_haplo_groups(), 1.0); 00717 break; 00718 } 00719 } 00720 00721 i = 0; 00722 for (it = xml_node->children; it; it = it->next) 00723 { 00724 j = 0; 00725 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "file")) 00726 { 00727 fname = (const char*) it->children->content; 00728 tree->model_fname = 
malloc((strlen(fname)+1) * sizeof(char)); 00729 strcpy(tree->model_fname, fname); 00730 } 00731 else if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "group")) 00732 { 00733 num_altlabels = 0; 00734 for (itt = it->children; itt; itt = itt->next) 00735 { 00736 if (itt->type == XML_ELEMENT_NODE && 00737 XMLStrEqual(itt->name, "altlabel")) 00738 { 00739 num_altlabels++; 00740 } 00741 } 00742 if (num_altlabels) 00743 { 00744 create_vector_u32(&(tree->altlabels[ i ]), num_altlabels); 00745 } 00746 00747 for (itt = it->children; itt; itt = itt->next) 00748 { 00749 if (itt->type == XML_ELEMENT_NODE && 00750 XMLStrEqual(itt->name, "label")) 00751 { 00752 if ((err = lookup_haplo_group_index_from_label(&label, 00753 (const char*) itt->children->content))) 00754 { 00755 return err; 00756 } 00757 tree->labels->elts[ i ] = label; 00758 } 00759 else if (itt->type == XML_ELEMENT_NODE && 00760 XMLStrEqual(itt->name, "altlabel")) 00761 { 00762 if ((err = lookup_haplo_group_index_from_label(&altlabel, 00763 (const char*) itt->children->content))) 00764 { 00765 return err; 00766 } 00767 tree->altlabels[ i ]->elts[ j++ ] = altlabel; 00768 } 00769 else if (tree->priors && itt->type == XML_ELEMENT_NODE && 00770 XMLStrEqual(itt->name, "prior")) 00771 { 00772 if (sscanf((char*)itt->children->content, "%lf", &p) != 1 || 00773 p < 0 || p > 1) 00774 { 00775 return JWSC_EARG("Invalid prior"); 00776 } 00777 tree->priors->elts[ i ] = p; 00778 } 00779 else if (itt->type == XML_ELEMENT_NODE && 00780 XMLStrEqual(itt->name, "model")) 00781 { 00782 if ((err = create_nb_freq_model_tree_from_xml_node( 00783 &(tree->subtrees[ i ]), tree, label, itt))) 00784 { 00785 return err; 00786 } 00787 } 00788 } 00789 i++; 00790 } 00791 } 00792 assert(i == tree->num_groups); 00793 00794 return NULL; 00795 } 00796 00797 00802 static Error* create_model_training_data 00803 ( 00804 Vector_u32** train_labels_out, 00805 Matrix_i32** train_markers_out, 00806 const Vector_u32* data_labels, 00807 const 
Matrix_i32* data_markers, 00808 const Vector_u32* model_labels, 00809 Vector_u32*const* model_altlabels 00810 ) 00811 { 00812 uint8_t b; 00813 uint32_t i, j, k; 00814 uint32_t n; 00815 Vector_u32* train_labels = NULL; 00816 Matrix_i32* train_markers = NULL; 00817 00818 copy_vector_u32(&train_labels, data_labels); 00819 copy_matrix_i32(&train_markers, data_markers); 00820 00821 n = 0; 00822 for (i = 0; i < data_labels->num_elts; i++) 00823 { 00824 for (j = 0; j < model_labels->num_elts; j++) 00825 { 00826 if (is_ancestor(data_labels->elts[ i ], model_labels->elts[ j ])) 00827 { 00828 train_labels->elts[ n ] = model_labels->elts[ j ]; 00829 copy_matrix_block_into_matrix_i32(train_markers, n, 0, 00830 data_markers, i, 0, 1, data_markers->num_cols); 00831 n++; 00832 break; 00833 } 00834 else if (model_altlabels[ j ]) 00835 { 00836 for (k = 0; k < model_altlabels[ j ]->num_elts; k++) 00837 { 00838 if ((b = is_ancestor(data_labels->elts[ i ], 00839 model_altlabels[ j ]->elts[ k ]))) 00840 { 00841 train_labels->elts[ n ] = model_labels->elts[ j ]; 00842 copy_matrix_block_into_matrix_i32(train_markers, n, 0, 00843 data_markers, i, 0, 1, data_markers->num_cols); 00844 n++; 00845 break; 00846 } 00847 } 00848 if (b) 00849 { 00850 break; 00851 } 00852 } 00853 } 00854 } 00855 00856 if (!n) 00857 { 00858 return JWSC_EARG("No data for model"); 00859 } 00860 00861 copy_vector_section_u32(train_labels_out, train_labels, 0, n); 00862 copy_matrix_block_i32(train_markers_out, train_markers, 0, 0, n, 00863 train_markers->num_cols); 00864 00865 free_vector_u32(train_labels); 00866 free_matrix_i32(train_markers); 00867 00868 return NULL; 00869 } 00870 00871 00873 static Error* train_nb_freq_model_node 00874 ( 00875 NB_freq_model_node* node, 00876 const Vector_u32* labels, 00877 const Matrix_i32* markers 00878 ) 00879 { 00880 uint32_t i; 00881 Vector_u32* node_labels = NULL; 00882 Matrix_i32* node_markers = NULL; 00883 Error* err = NULL; 00884 int slen = 256; 00885 char str[slen]; 00886 
00887 if ((err = create_model_training_data(&node_labels, &node_markers, labels, 00888 markers, node->labels, node->altlabels))) 00889 { 00890 snprintf(str, slen, "%s: %s", node->model_fname, err->msg); 00891 return JWSC_EARG(str); 00892 } 00893 00894 train_nb_freq_model(&(node->model), node_labels, node_markers, 00895 node->priors); 00896 00897 for (i = 0; i < node->num_groups; i++) 00898 { 00899 if (node->subtrees[ i ]) 00900 { 00901 if ((err = train_nb_freq_model_node(node->subtrees[ i ], labels, 00902 markers))) 00903 { 00904 return err; 00905 } 00906 } 00907 } 00908 00909 free_vector_u32(node_labels); 00910 free_matrix_i32(node_markers); 00911 00912 return NULL; 00913 } 00914 00915 00923 Error* train_nb_freq_model_tree 00924 ( 00925 NB_freq_model_tree** tree_out, 00926 const Vector_u32* labels, 00927 const Matrix_i32* markers, 00928 const char* tree_xml_fname, 00929 const char* tree_dtd_fname 00930 ) 00931 { 00932 xmlDoc* xml_doc; 00933 xmlNode* xml_node; 00934 xmlNode* it; 00935 Error* err; 00936 int slen = 256; 00937 char str[slen]; 00938 00939 assert(tree_xml_fname); 00940 00941 if ((err = read_nb_freq_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname))) 00942 { 00943 return err; 00944 } 00945 00946 for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next) 00947 { 00948 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model")) 00949 { 00950 xml_node = it; 00951 break; 00952 } 00953 } 00954 00955 if ((err = create_nb_freq_model_tree_from_xml_node(tree_out, NULL, 0, 00956 xml_node))) 00957 { 00958 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 00959 return JWSC_EARG(str); 00960 } 00961 00962 if ((err = train_nb_freq_model_node(*tree_out, labels, markers))) 00963 { 00964 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 00965 return JWSC_EARG(str); 00966 } 00967 00968 return NULL; 00969 } 00970 00971 00973 static Error* recursively_predict_label_in_model_tree 00974 ( 00975 uint32_t* label_out, 00976 double* confidence_out, 
00977 const NB_freq_model_tree* tree, 00978 const Vector_i32* markers_v, 00979 uint32_t order 00980 ) 00981 { 00982 uint32_t subtree_label; 00983 uint32_t i; 00984 double confidence; 00985 Error* err; 00986 00987 if ((err = predict_label_with_nb_freq_model(label_out, &confidence, 00988 markers_v, tree->model, order)) != NULL) 00989 { 00990 return err; 00991 } 00992 *confidence_out *= confidence; 00993 00994 for (i = 0; i < tree->num_groups; i++) 00995 { 00996 if (tree->subtrees[ i ]) 00997 { 00998 subtree_label = tree->subtrees[ i ]->parent_label; 00999 01000 if (subtree_label == *label_out) 01001 { 01002 if ((err = recursively_predict_label_in_model_tree(label_out, 01003 confidence_out, tree->subtrees[ i ], markers_v, 01004 order)) != NULL) 01005 { 01006 return err; 01007 } 01008 } 01009 } 01010 } 01011 01012 return NULL; 01013 } 01014 01015 01031 Error* predict_labels_with_nb_freq_model_tree 01032 ( 01033 Vector_u32** labels_out, 01034 Vector_d** confidence_out, 01035 const Matrix_i32* markers, 01036 const NB_freq_model_tree* tree, 01037 uint32_t order 01038 ) 01039 { 01040 uint32_t s, num_samples; 01041 uint32_t m, num_markers; 01042 01043 Vector_u32* labels; 01044 Vector_d* confidence; 01045 Vector_i32* markers_v; 01046 Error* err; 01047 01048 num_samples = markers->num_rows; 01049 num_markers = markers->num_cols; 01050 01051 create_vector_u32(labels_out, num_samples); 01052 labels = *labels_out; 01053 01054 create_init_vector_d(confidence_out, num_samples, 1.0); 01055 confidence = *confidence_out; 01056 01057 markers_v = NULL; 01058 create_vector_i32(&markers_v, num_markers); 01059 01060 for (s = 0; s < num_samples; s++) 01061 { 01062 for (m = 0; m < num_markers; m++) 01063 { 01064 markers_v->elts[ m ]= markers->elts[ s ][ m ]; 01065 } 01066 01067 if ((err = recursively_predict_label_in_model_tree( 01068 &(labels->elts[ s ]), 01069 &(confidence->elts[ s ]), tree, markers_v, order)) 01070 != NULL) 01071 { 01072 return err; 01073 } 01074 } 01075 01076 
free_vector_i32(markers_v); 01077 01078 return NULL; 01079 } 01080 01081 01083 static Error* read_nb_freq_model_node 01084 ( 01085 NB_freq_model_node* node, 01086 const char* model_dirname 01087 ) 01088 { 01089 uint32_t i; 01090 char buf[1024] = {0}; 01091 Error* err; 01092 01093 snprintf(buf, 1024, "%s/%s", model_dirname, node->model_fname); 01094 if ((err = read_nb_freq_model(&(node->model), buf))) 01095 { 01096 return err; 01097 } 01098 01099 for (i = 0; i < node->num_groups; i++) 01100 { 01101 if (node->subtrees[ i ]) 01102 { 01103 if ((err = read_nb_freq_model_node(node->subtrees[ i ], 01104 model_dirname))) 01105 { 01106 return err; 01107 } 01108 } 01109 } 01110 01111 return NULL; 01112 } 01113 01114 01121 Error* read_nb_freq_model_tree 01122 ( 01123 NB_freq_model_tree** tree_out, 01124 const char* tree_xml_fname, 01125 const char* tree_dtd_fname, 01126 const char* model_dirname 01127 ) 01128 { 01129 xmlDoc* xml_doc; 01130 xmlNode* xml_node; 01131 xmlNode* it; 01132 Error* err; 01133 int slen = 256; 01134 char str[slen]; 01135 01136 assert(tree_xml_fname && model_dirname); 01137 01138 if ((err = read_nb_freq_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname))) 01139 { 01140 return err; 01141 } 01142 01143 for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next) 01144 { 01145 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model")) 01146 { 01147 xml_node = it; 01148 break; 01149 } 01150 } 01151 01152 if ((err = create_nb_freq_model_tree_from_xml_node(tree_out, NULL, 0, 01153 xml_node))) 01154 { 01155 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 01156 return JWSC_EARG(str); 01157 } 01158 01159 if ((err = read_nb_freq_model_node(*tree_out, model_dirname))) 01160 { 01161 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 01162 return JWSC_EARG(str); 01163 } 01164 01165 return NULL; 01166 } 01167 01168 01173 Error* write_nb_freq_model_tree 01174 ( 01175 const NB_freq_model_tree* tree, 01176 const char* model_dirname 01177 ) 
01178 { 01179 uint32_t i; 01180 char buf[1024] = {0}; 01181 Error* err; 01182 int slen = 256; 01183 char str[slen]; 01184 01185 snprintf(buf, 1024, "%s/%s", model_dirname, tree->model_fname); 01186 if ((err = write_nb_freq_model(tree->model, buf))) 01187 { 01188 return err; 01189 } 01190 01191 for (i = 0; i < tree->num_groups; i++) 01192 { 01193 if (tree->subtrees[ i ]) 01194 { 01195 if ((err = write_nb_freq_model_tree(tree->subtrees[ i ], 01196 model_dirname))) 01197 { 01198 snprintf(str, slen, "%s: %s", tree->model_fname, err->msg); 01199 return JWSC_EARG(str); 01200 } 01201 } 01202 } 01203 01204 return NULL; 01205 } 01206 01207 01209 void free_nb_freq_model_tree(NB_freq_model_tree* tree) 01210 { 01211 uint32_t i; 01212 01213 if (!tree) 01214 return; 01215 01216 for (i = 0; i < tree->num_groups; i++) 01217 { 01218 free_nb_freq_model_tree(tree->subtrees[ i ]); 01219 free_vector_u32(tree->altlabels[ i ]); 01220 } 01221 01222 free(tree->subtrees); 01223 free(tree->altlabels); 01224 free_vector_u32(tree->labels); 01225 free_vector_d(tree->priors); 01226 free_nb_freq_model(tree->model); 01227 free(tree->model_fname); 01228 free(tree); 01229 }