Haplo Prediction
predict haplogroups
nb_gauss.c
Go to the documentation of this file.
00001 /*
00002  * This work is licensed under a Creative Commons 
00003  * Attribution-Noncommercial-Share Alike 3.0 United States License.
00004  * 
00005  *    http://creativecommons.org/licenses/by-nc-sa/3.0/us/
00006  * 
00007  * You are free:
00008  * 
00009  *    to Share - to copy, distribute, display, and perform the work
00010  *    to Remix - to make derivative works
00011  * 
00012  * Under the following conditions:
00013  * 
00014  *    Attribution. You must attribute the work in the manner specified by the
00015  *    author or licensor (but not in any way that suggests that they endorse you
00016  *    or your use of the work).
00017  * 
00018  *    Noncommercial. You may not use this work for commercial purposes.
00019  * 
00020  *    Share Alike. If you alter, transform, or build upon this work, you may
00021  *    distribute the resulting work only under the same or similar license to
00022  *    this one.
00023  * 
00024  * For any reuse or distribution, you must make clear to others the license
00025  * terms of this work. The best way to do this is by including this header.
00026  * 
00027  * Any of the above conditions can be waived if you get permission from the
00028  * copyright holder.
00029  * 
00030  * Apart from the remix rights granted under this license, nothing in this
00031  * license impairs or restricts the author's moral rights.
00032  */
00033 
00034 
00046 #include <config.h>
00047 
00048 #include <stdlib.h>
00049 #include <stdio.h>
00050 #include <inttypes.h>
00051 #include <assert.h>
00052 #include <string.h>
00053 #include <errno.h>
00054 #include <math.h>
00055 
00056 #include <libxml/tree.h>
00057 #include <libxml/parser.h>
00058 #include <libxml/valid.h>
00059 
00060 #ifdef HAPLO_HAVE_DMALLOC
00061 #include <dmalloc.h>
00062 #endif
00063 
00064 #include <jwsc/base/error.h>
00065 #include <jwsc/base/limits.h>
00066 #include <jwsc/base/file_io.h>
00067 #include <jwsc/vector/vector.h>
00068 #include <jwsc/vector/vector_io.h>
00069 #include <jwsc/vector/vector_math.h>
00070 #include <jwsc/matrix/matrix.h>
00071 #include <jwsc/matrix/matrix_io.h>
00072 #include <jwsc/prob/pdf.h>
00073 
00074 #include "xml.h"
00075 #include "haplo_groups.h"
00076 #include "nb_gauss.h"
00077 
00078 
00088 void train_nb_gauss_model
00089 (
00090     NB_gauss_model**  model_out,
00091     const Vector_u32* labels, 
00092     const Matrix_i32* markers,
00093     const Vector_d*   label_priors
00094 )
00095 {
00096     uint32_t  num_labels;
00097     uint32_t  label;
00098     uint32_t  num_samples;
00099     uint32_t  sample;
00100     uint32_t  num_markers;
00101     uint32_t  marker;
00102     double    marker_val;
00103 
00104     Matrix_d*   mu;
00105     Matrix_d*   sigma;
00106     Vector_d*   priors;
00107     Matrix_u32* marker_counts;
00108 
00109     assert(markers->num_rows == labels->num_elts);
00110 
00111     num_labels = get_num_haplo_groups();
00112 
00113     if (*model_out != NULL)
00114     {
00115         mu     = (*model_out)->mu;
00116         sigma  = (*model_out)->sigma;
00117         priors = (*model_out)->priors;
00118     }
00119     else
00120     {
00121         mu     = NULL;
00122         sigma  = NULL;
00123         priors = NULL;
00124     }
00125 
00126     num_markers = markers->num_cols;
00127     num_samples = markers->num_rows;
00128 
00129     create_zero_matrix_d(&mu, num_labels, num_markers);
00130     create_init_matrix_d(&sigma, num_labels, num_markers, 0.1);
00131 
00132     marker_counts = NULL;
00133     create_zero_matrix_u32(&marker_counts, num_labels, num_markers);
00134 
00135     for (sample = 0; sample < num_samples; sample++)
00136     {
00137         label = labels->elts[ sample ];
00138 
00139         for (marker = 0; marker < num_markers; marker++)
00140         {
00141             marker_val = markers->elts[ sample ][ marker ];
00142 
00143             if (marker_val > 0)
00144             {
00145                 marker_counts->elts[ label ][ marker ]++;
00146                 mu->elts[ label ][ marker ] += marker_val;
00147                 sigma->elts[ label ][ marker ] += marker_val*marker_val;
00148             }
00149         }
00150     }
00151     for (label = 0; label < num_labels; label++)
00152     {
00153         for (marker = 0; marker < num_markers; marker++)
00154         {
00155             if (marker_counts->elts[ label ][ marker ] > 0)
00156             {
00157                 mu->elts[ label ][ marker ] /= 
00158                     marker_counts->elts[ label ][ marker ];
00159 
00160                 sigma->elts[ label ][ marker ] /= 
00161                     marker_counts->elts[ label ][ marker ];
00162 
00163                 sigma->elts[ label ][ marker ] -= 
00164                     mu->elts[ label ][ marker ] * mu->elts[ label ][ marker ];
00165 
00166                 sigma->elts[ label ][ marker ] = 
00167                     sqrt(sigma->elts[ label ][ marker ]);
00168             }
00169         }
00170     }
00171 
00172     if (label_priors == NULL)
00173     {
00174         create_zero_vector_d(&priors, num_labels);
00175 
00176         for (sample = 0; sample < num_samples; sample++)
00177         {
00178             label = labels->elts[ sample ];
00179             priors->elts[ label ]++;
00180         }
00181         normalize_vector_sum_d(&priors, priors);
00182     }
00183     else
00184     {
00185         normalize_vector_sum_d(&priors, label_priors);
00186     }
00187 
00188     if (*model_out == NULL)
00189     {
00190         assert(*model_out = malloc(sizeof(NB_gauss_model)));
00191     }
00192     (*model_out)->num_markers  = num_markers;
00193     (*model_out)->mu           = mu;
00194     (*model_out)->sigma        = sigma;
00195     (*model_out)->priors       = priors;
00196 
00197     free_matrix_u32(marker_counts);
00198 }
00199 
00200 
00212 Error* predict_label_with_nb_gauss_model
00213 (
00214     uint32_t*             label_out,
00215     double*               confidence_out,
00216     const Vector_i32*     markers,
00217     const NB_gauss_model* model,
00218     uint32_t             order
00219 )
00220 {
00221     uint32_t num_markers;
00222     uint32_t marker;
00223     uint32_t num_labels;
00224     uint32_t label;
00225     uint32_t best_label;
00226     uint32_t i;
00227     double   marker_val;
00228     double   mu, sigma;
00229     double   log_map;
00230     double   log_ll;
00231     double   log_prior;
00232     double   log_posterior;
00233     double   posterior_sum;
00234 
00235     Vector_u32* best_labels = NULL;
00236     Vector_d* best_labels_conf = NULL;
00237 
00238     num_labels = get_num_haplo_groups();
00239     num_markers = markers->num_elts;
00240 
00241     log_map = log(JWSC_MIN_LOG_ARG);
00242     posterior_sum = 0;
00243 
00244     create_init_vector_u32(&best_labels, num_labels, 0);
00245     create_init_vector_d(&best_labels_conf, num_labels, 0);
00246 
00247     for (label = 0; label < num_labels; label++)
00248     {
00249         log_ll = 0;
00250 
00251         for (marker = 0; marker < num_markers; marker++)
00252         {
00253             marker_val = markers->elts[ marker ];
00254 
00255             /* Ignore non-positive marker values. */
00256             if (marker_val > 0)
00257             {
00258                 mu = model->mu->elts[ label ][ marker ];
00259                 sigma = model->sigma->elts[ label ][ marker ];
00260 
00261                 log_ll += log_gaussian_pdf_d(mu, sigma, marker_val);
00262             }
00263         }
00264 
00265         if (model->priors->elts[ label ] < JWSC_MIN_LOG_ARG)
00266         {
00267             log_prior = log(JWSC_MIN_LOG_ARG);
00268         }
00269         else
00270         {
00271             log_prior = log(model->priors->elts[ label ]);
00272         }
00273 
00274         log_posterior = log_ll + log_prior;
00275         posterior_sum += exp(log_posterior);
00276 
00277         if (log_posterior > log_map)
00278         {
00279             log_map = log_posterior;
00280             best_label = label;
00281             for (i = num_labels-1; i > 0; i--)
00282             {
00283                 best_labels->elts[ i ] = best_labels->elts[ i - 1];
00284                 best_labels_conf->elts[ i ] = best_labels_conf->elts[ i - 1];
00285             }
00286             best_labels->elts[ 0 ] = best_label;
00287             best_labels_conf->elts[ 0 ] = log_map;
00288         }
00289     }
00290 
00291     assert(order < best_labels->num_elts);
00292 
00293     *label_out = best_labels->elts[ order ];
00294 
00295     if (posterior_sum > 0)
00296     {
00297         assert(finite(*confidence_out = exp(best_labels_conf->elts[order]) / 
00298             (double)posterior_sum));
00299     }
00300     else
00301     {
00302         *confidence_out = 0;
00303     }
00304 
00305     free_vector_u32(best_labels);
00306     free_vector_d(best_labels_conf);
00307 
00308     return NULL;
00309 }
00310 
00311 
00325 Error* predict_labels_with_nb_gauss_model
00326 (
00327     Vector_u32**          labels_out,
00328     Vector_d**            confidence_out,
00329     const Matrix_i32*     markers,
00330     const NB_gauss_model* model
00331 )
00332 {
00333     uint32_t num_samples;
00334     uint32_t sample;
00335     uint32_t num_markers;
00336     uint32_t marker;
00337 
00338     Vector_u32* labels;
00339     Vector_d*   confidence;
00340     Vector_i32* markers_v;
00341     Error*      e;
00342 
00343     num_samples = markers->num_rows;
00344     num_markers = markers->num_cols;
00345 
00346     create_vector_u32(labels_out, num_samples);
00347     labels = *labels_out;
00348 
00349     create_vector_d(confidence_out, num_samples);
00350     confidence = *confidence_out;
00351 
00352     markers_v = NULL;
00353     create_vector_i32(&markers_v, num_markers);
00354 
00355     for (sample = 0; sample < num_samples; sample++)
00356     {
00357         for (marker = 0; marker < num_markers; marker++)
00358         {
00359             markers_v->elts[ marker ] = markers->elts[ sample ][marker];
00360         }
00361 
00362         if ((e = predict_label_with_nb_gauss_model(&(labels->elts[ sample ]), 
00363                 &(confidence->elts[ sample ]), markers_v, model, 0)) != NULL)
00364         {
00365             return e;
00366         }
00367     }
00368 
00369     free_vector_i32(markers_v);
00370 
00371     return NULL;
00372 }
00373 
00374 
00383 Error* read_nb_gauss_model(NB_gauss_model** model_out, const char* fname)
00384 {
00385     FILE*           fp;
00386     NB_gauss_model* model;
00387     Error*          e;
00388     int             slen = 256;
00389     char            str[slen];
00390 
00391     if ((fp = fopen(fname, "r")) == NULL)
00392     {
00393         snprintf(str, slen, "%s: %s", fname, strerror(errno));
00394         return JWSC_EIO(str);
00395     }
00396 
00397     if (*model_out == NULL)
00398     {
00399         assert((*model_out = malloc(sizeof(NB_gauss_model))) != NULL);
00400         model = *model_out;
00401         model->mu = NULL;
00402         model->sigma = NULL;
00403         model->priors = NULL;
00404     }
00405 
00406     if (fscanf(fp, "%u\n", &(model->num_markers)) != 1)
00407     {
00408         snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file");
00409         return JWSC_EIO(str);
00410     }
00411 
00412     if ((e = read_matrix_with_header_fp_d(&(model->mu), fp)) != NULL)
00413     {
00414         snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file");
00415         return JWSC_EIO(str);
00416     }
00417 
00418     if ((e = read_matrix_with_header_fp_d(&(model->sigma), fp)) != NULL)
00419     {
00420         snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file");
00421         return JWSC_EIO(str);
00422     }
00423 
00424     if ((e = read_vector_with_header_fp_d(&(model->priors), fp)) != NULL)
00425     {
00426         snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file");
00427         return JWSC_EIO(str);
00428     }
00429 
00430     if (fclose(fp) != 0)
00431     {
00432         snprintf(str, slen, "%s: %s", fname, strerror(errno));
00433         return JWSC_EIO(str);
00434     }
00435 
00436     return NULL;
00437 }
00438 
00439 
00446 Error* write_nb_gauss_model(NB_gauss_model* model, const char* fname)
00447 {
00448     FILE* fp;
00449     int   slen = 256;
00450     char  str[slen];
00451 
00452     if ((fp = fopen(fname, "w")) == NULL)
00453     {
00454         snprintf(str, slen, "%s: %s", fname, strerror(errno));
00455         return JWSC_EIO(str);
00456     }
00457 
00458     fprintf(fp, "%u\n", model->num_markers);
00459 
00460     write_matrix_with_header_fp_d(model->mu, fp);
00461     write_matrix_with_header_fp_d(model->sigma, fp);
00462     write_vector_with_header_fp_d(model->priors, fp);
00463 
00464     if (fclose(fp) != 0)
00465     {
00466         snprintf(str, slen, "%s: %s", fname, strerror(errno));
00467         return JWSC_EIO(str);
00468     }
00469 
00470     return NULL;
00471 }
00472 
00473 
00475 void free_nb_gauss_model(NB_gauss_model* model)
00476 {
00477     if (!model)
00478         return;
00479 
00480     free_matrix_d(model->mu);
00481     free_matrix_d(model->sigma);
00482     free_vector_d(model->priors);
00483     free(model);
00484 }
00485 
00486 
00488 static Error* read_nb_gauss_xml_doc
00489 (
00490     xmlDoc**    xml_doc_out,
00491     const char* xml_fname,
00492     const char* dtd_fname
00493 )
00494 {
00495     xmlParserCtxt* xml_parse_ctxt;
00496     xmlValidCtxt*  xml_valid_ctxt;
00497     xmlDtd*        xml_dtd;
00498     int            slen = 256;
00499     char           str[slen];
00500 
00501     assert(xml_parse_ctxt = xmlNewParserCtxt());
00502 
00503     if (!(*xml_doc_out = xmlCtxtReadFile(xml_parse_ctxt, xml_fname, NULL, 0)))
00504     {
00505         snprintf(str, slen, "%s: %s", xml_fname, "Could not parse file");
00506         return JWSC_EIO(str);
00507     } 
00508 
00509     xmlFreeParserCtxt(xml_parse_ctxt);
00510 
00511     if (dtd_fname)
00512     {
00513         assert(xml_valid_ctxt = xmlNewValidCtxt());
00514 
00515         if (!(xml_dtd = xmlParseDTD(NULL, (xmlChar*)dtd_fname)))
00516         {
00517             snprintf(str, slen, "%s: %s", dtd_fname, "Could not parse DTD");
00518             return JWSC_EIO(str);
00519         }
00520 
00521         if (!xmlValidateDtd(xml_valid_ctxt, *xml_doc_out, xml_dtd))
00522         {
00523             snprintf(str, slen, "%s: %s", xml_fname, "XML file not valid");
00524             return JWSC_EIO(str);
00525         }
00526 
00527         xmlFreeValidCtxt(xml_valid_ctxt);
00528         xmlFreeDtd(xml_dtd);
00529     }
00530 
00531     return NULL;
00532 }
00533 
00534 
00541 static Error* create_nb_gauss_model_tree_from_xml_node
00542 (
00543      NB_gauss_model_tree**      tree_out, 
00544      const NB_gauss_model_tree* parent,
00545      uint32_t                   parent_label,
00546      xmlNode*                   xml_node
00547 )
00548 {
00549     NB_gauss_model_tree* tree;
00550     uint32_t i, j;
00551     uint32_t label;
00552     uint32_t altlabel;
00553     uint32_t num_altlabels;
00554     double   p;
00555     xmlAttr* attr;
00556     xmlNode* it;
00557     xmlNode* itt;
00558     Error*   err;
00559     const char* fname;
00560 
00561     if (*tree_out)
00562     {
00563         free_nb_gauss_model_tree(*tree_out);
00564     }
00565 
00566     assert(*tree_out = malloc(sizeof(NB_gauss_model_tree)));
00567     tree = *tree_out;
00568 
00569     tree->parent       = parent;
00570     tree->parent_label = parent_label;
00571     tree->num_groups   = 0;
00572     tree->subtrees     = NULL;
00573     tree->labels       = NULL;
00574     tree->altlabels    = NULL;
00575     tree->priors       = NULL;
00576     tree->model        = NULL;
00577     tree->model_fname  = NULL;
00578 
00579     assert(XMLStrEqual(xml_node->name, "model"));
00580 
00581     for (it = xml_node->children; it; it = it->next)
00582     {
00583         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "group"))
00584         {
00585             tree->num_groups++;
00586         }
00587     }
00588 
00589     assert(tree->subtrees = calloc(tree->num_groups, sizeof(void*)));
00590     assert(tree->altlabels = calloc(tree->num_groups, sizeof(void*)));
00591     create_vector_u32(&(tree->labels), tree->num_groups);
00592 
00593     for (attr = xml_node->properties; attr; attr = attr->next)
00594     {
00595         if (XMLStrEqual(attr->name, "priors") && 
00596                 XMLStrEqual(attr->children->content, "true"))
00597         {
00598             create_init_vector_d(&(tree->priors), get_num_haplo_groups(), 1.0);
00599             break;
00600         }
00601     }
00602 
00603     i = 0;
00604     for (it = xml_node->children; it; it = it->next)
00605     {
00606         j = 0;
00607         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "file"))
00608         {
00609             fname = (const char*) it->children->content;
00610             tree->model_fname = malloc((strlen(fname)+1) * sizeof(char));
00611             strcpy(tree->model_fname, fname);
00612         }
00613         else if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "group"))
00614         {
00615             num_altlabels = 0;
00616             for (itt = it->children; itt; itt = itt->next)
00617             {
00618                 if (itt->type == XML_ELEMENT_NODE &&
00619                         XMLStrEqual(itt->name, "altlabel"))
00620                 {
00621                     num_altlabels++;
00622                 }
00623             }
00624             if (num_altlabels)
00625             {
00626                 create_vector_u32(&(tree->altlabels[ i ]), num_altlabels);
00627             }
00628 
00629             for (itt = it->children; itt; itt = itt->next)
00630             {
00631                 if (itt->type == XML_ELEMENT_NODE &&
00632                         XMLStrEqual(itt->name, "label"))
00633                 {
00634                     if ((err = lookup_haplo_group_index_from_label(&label, 
00635                                     (const char*) itt->children->content)))
00636                     {
00637                         return err;
00638                     }
00639                     tree->labels->elts[ i ] = label;
00640                 }
00641                 else if (itt->type == XML_ELEMENT_NODE &&
00642                         XMLStrEqual(itt->name, "altlabel"))
00643                 {
00644                     if ((err = lookup_haplo_group_index_from_label(&altlabel, 
00645                                     (const char*) itt->children->content)))
00646                     {
00647                         return err;
00648                     }
00649                     tree->altlabels[ i ]->elts[ j++ ] = altlabel;
00650                 }
00651                 else if (tree->priors && itt->type == XML_ELEMENT_NODE &&
00652                         XMLStrEqual(itt->name, "prior"))
00653                 {
00654                     if (sscanf((char*)itt->children->content, "%lf", &p) != 1 ||
00655                             p < 0 || p > 1)
00656                     {
00657                         return JWSC_EARG("Invalid prior");
00658                     }
00659                     tree->priors->elts[ i ] = p;
00660                 }
00661                 else if (itt->type == XML_ELEMENT_NODE &&
00662                         XMLStrEqual(itt->name, "model"))
00663                 {
00664                     if ((err = create_nb_gauss_model_tree_from_xml_node(
00665                                  &(tree->subtrees[ i ]), tree, label, itt)))
00666                     {
00667                         return err;
00668                     }
00669                 }
00670             }
00671             i++;
00672         }
00673     }
00674     assert(i == tree->num_groups);
00675 
00676     return NULL;
00677 }
00678 
00679 
00684 static Error* create_model_training_data
00685 (
00686     Vector_u32**      train_labels_out,
00687     Matrix_i32**      train_markers_out,
00688     const Vector_u32* data_labels,
00689     const Matrix_i32* data_markers,
00690     const Vector_u32* model_labels,
00691     Vector_u32*const* model_altlabels
00692 )
00693 {
00694     uint8_t     b;
00695     uint32_t    i, j, k;
00696     uint32_t    n;
00697     Vector_u32* train_labels = NULL;
00698     Matrix_i32* train_markers = NULL;
00699 
00700     copy_vector_u32(&train_labels, data_labels);
00701     copy_matrix_i32(&train_markers, data_markers);
00702 
00703     n = 0;
00704     for (i = 0; i < data_labels->num_elts; i++)
00705     {
00706         for (j = 0; j < model_labels->num_elts; j++)
00707         {
00708             if (is_ancestor(data_labels->elts[ i ], model_labels->elts[ j ]))
00709             {
00710                 train_labels->elts[ n ] = model_labels->elts[ j ];
00711                 copy_matrix_block_into_matrix_i32(train_markers, n, 0,
00712                         data_markers, i, 0, 1, data_markers->num_cols);
00713                 n++;
00714                 break;
00715             }
00716             else if (model_altlabels[ j ])
00717             {
00718                 for (k = 0; k < model_altlabels[ j ]->num_elts; k++)
00719                 {
00720                     if ((b = is_ancestor(data_labels->elts[ i ], 
00721                                 model_altlabels[ j ]->elts[ k ])))
00722                     {
00723                         train_labels->elts[ n ] = model_labels->elts[ j ];
00724                         copy_matrix_block_into_matrix_i32(train_markers, n, 0,
00725                                 data_markers, i, 0, 1, data_markers->num_cols);
00726                         n++;
00727                         break;
00728                     }
00729                 }
00730                 if (b)
00731                 {
00732                     break;
00733                 }
00734             }
00735         }
00736     }
00737 
00738     if (!n)
00739     {
00740         return JWSC_EARG("No data for model");
00741     }
00742 
00743     copy_vector_section_u32(train_labels_out, train_labels, 0, n);
00744     copy_matrix_block_i32(train_markers_out, train_markers, 0, 0, n, 
00745             train_markers->num_cols);
00746 
00747     free_vector_u32(train_labels);
00748     free_matrix_i32(train_markers);
00749 
00750     return NULL;
00751 }
00752 
00753 
00755 static Error* train_nb_gauss_model_node
00756 (
00757     NB_gauss_model_node* node,
00758     const Vector_u32*    labels, 
00759     const Matrix_i32*    markers
00760 )
00761 {
00762     uint32_t i;
00763     Vector_u32* node_labels  = NULL;
00764     Matrix_i32* node_markers = NULL;
00765     Error*  err;
00766     int     slen = 256;
00767     char    str[slen];
00768 
00769     if ((err = create_model_training_data(&node_labels, &node_markers, labels,
00770                     markers, node->labels, node->altlabels)))
00771     {
00772         snprintf(str, slen, "%s: %s", node->model_fname, err->msg);
00773         return JWSC_EARG(str);
00774     }
00775 
00776     train_nb_gauss_model(&(node->model), node_labels, node_markers, 
00777             node->priors);
00778 
00779     for (i = 0; i < node->num_groups; i++)
00780     {
00781         if (node->subtrees[ i ])
00782         {
00783             if ((err = train_nb_gauss_model_node(node->subtrees[ i ], labels,
00784                         markers)))
00785             {
00786                 return err;
00787             }
00788         }
00789     }
00790 
00791     free_vector_u32(node_labels);
00792     free_matrix_i32(node_markers);
00793 
00794     return NULL;
00795 }
00796 
00797 
00805 Error* train_nb_gauss_model_tree
00806 (
00807     NB_gauss_model_tree** tree_out,
00808     const Vector_u32*     labels, 
00809     const Matrix_i32*     markers,
00810     const char*           tree_xml_fname,
00811     const char*           tree_dtd_fname
00812 )
00813 {
00814     xmlDoc*  xml_doc;
00815     xmlNode* xml_node;
00816     xmlNode* it;
00817     Error*   err;
00818     int      slen = 256;
00819     char     str[slen];
00820 
00821     assert(tree_xml_fname);
00822 
00823     if ((err = read_nb_gauss_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname)))
00824     {
00825         return err;
00826     }
00827 
00828     for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next)
00829     {
00830         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model"))
00831         {
00832             xml_node = it;
00833             break;
00834         }
00835     }
00836 
00837     if ((err = create_nb_gauss_model_tree_from_xml_node(tree_out, NULL, 0, 
00838                     xml_node)))
00839     {
00840         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
00841         return JWSC_EARG(str);
00842     }
00843 
00844     if ((err = train_nb_gauss_model_node(*tree_out, labels, markers)))
00845     {
00846         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
00847         return JWSC_EARG(str);
00848     }
00849 
00850     return NULL;
00851 }
00852 
00853 
00855 static Error* recursively_predict_label_in_model_tree
00856 (
00857     uint32_t*                  label_out,
00858     double*                    confidence_out,
00859     const NB_gauss_model_tree* tree,
00860     const Vector_i32*          markers_v,
00861     uint32_t                   order
00862 )
00863 {
00864     uint32_t subtree_label;
00865     uint32_t i;
00866     double   confidence;
00867     Error*   err;
00868 
00869     if ((err = predict_label_with_nb_gauss_model(label_out, &confidence, 
00870                     markers_v, tree->model, order)) != NULL)
00871     {
00872         return err;
00873     }
00874     *confidence_out *= confidence;
00875 
00876     for (i = 0; i < tree->num_groups; i++)
00877     {
00878         if (tree->subtrees[ i ])
00879         {
00880             subtree_label = tree->subtrees[ i ]->parent_label;
00881 
00882             if (subtree_label == *label_out)
00883             {
00884                 if ((err = recursively_predict_label_in_model_tree(label_out,
00885                                 confidence_out, tree->subtrees[ i ], markers_v,
00886                                 order)) != NULL)
00887                 {
00888                     return err;
00889                 }
00890             }
00891         }
00892     }
00893 
00894     return NULL;
00895 }
00896 
00897 
00913 Error* predict_labels_with_nb_gauss_model_tree
00914 (
00915     Vector_u32**              labels_out,
00916     Vector_d**                confidence_out,
00917     const Matrix_i32*         markers,
00918     const NB_gauss_model_tree* tree,
00919     uint32_t                  order
00920 )
00921 {
00922     uint32_t s, num_samples;
00923     uint32_t m, num_markers;
00924 
00925     Vector_u32* labels;
00926     Vector_d*   confidence;
00927     Vector_i32* markers_v;
00928     Error*      err;
00929 
00930     num_samples = markers->num_rows;
00931     num_markers = markers->num_cols;
00932 
00933     create_vector_u32(labels_out, num_samples);
00934     labels = *labels_out;
00935 
00936     create_init_vector_d(confidence_out, num_samples, 1.0);
00937     confidence = *confidence_out;
00938 
00939     markers_v = NULL;
00940     create_vector_i32(&markers_v, num_markers);
00941 
00942     for (s = 0; s < num_samples; s++)
00943     {
00944         for (m = 0; m < num_markers; m++)
00945         {
00946             markers_v->elts[ m ]= markers->elts[ s ][ m ];
00947         }
00948 
00949         if ((err = recursively_predict_label_in_model_tree(
00950                         &(labels->elts[ s ]), 
00951                         &(confidence->elts[ s ]), tree, markers_v, order)) 
00952                 != NULL)
00953         {
00954             return err;
00955         }
00956     }
00957 
00958     free_vector_i32(markers_v);
00959 
00960     return NULL;
00961 }
00962 
00963 
00965 static Error* read_nb_gauss_model_node
00966 (
00967      NB_gauss_model_node* node, 
00968      const char*          model_dirname
00969 )
00970 {
00971     uint32_t    i;
00972     char        buf[1024] = {0};
00973     Error*      err;
00974 
00975     snprintf(buf, 1024, "%s/%s", model_dirname, node->model_fname);
00976     if ((err = read_nb_gauss_model(&(node->model), buf)))
00977     {
00978         return err;
00979     }
00980 
00981     for (i = 0; i < node->num_groups; i++)
00982     {
00983         if (node->subtrees[ i ])
00984         {
00985             if ((err = read_nb_gauss_model_node(node->subtrees[ i ], 
00986                             model_dirname)))
00987             {
00988                 return err;
00989             }
00990         }
00991     }
00992 
00993     return NULL;
00994 }
00995 
00996 
01003 Error* read_nb_gauss_model_tree
01004 (
01005     NB_gauss_model_tree** tree_out,
01006     const char*           tree_xml_fname,
01007     const char*           tree_dtd_fname,
01008     const char*           model_dirname
01009 )
01010 {
01011     xmlDoc*  xml_doc;
01012     xmlNode* xml_node;
01013     xmlNode* it;
01014     Error*   err;
01015     int      slen = 256;
01016     char     str[slen];
01017 
01018     assert(tree_xml_fname && model_dirname);
01019 
01020     if ((err = read_nb_gauss_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname)))
01021     {
01022         return err;
01023     }
01024 
01025     for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next)
01026     {
01027         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model"))
01028         {
01029             xml_node = it;
01030             break;
01031         }
01032     }
01033 
01034     if ((err = create_nb_gauss_model_tree_from_xml_node(tree_out, NULL, 0,
01035                     xml_node)))
01036     {
01037         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
01038         return JWSC_EARG(str);
01039     }
01040 
01041     if ((err = read_nb_gauss_model_node(*tree_out, model_dirname)))
01042     {
01043         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
01044         return JWSC_EARG(str);
01045     }
01046 
01047     return NULL;
01048 }
01049 
01050 
01055 Error* write_nb_gauss_model_tree
01056 (
01057     const NB_gauss_model_tree* tree,
01058     const char*                model_dirname
01059 )
01060 {
01061     uint32_t i;
01062     char     buf[1024] = {0};
01063     Error*   err;
01064     int      slen = 256;
01065     char     str[slen];
01066 
01067     snprintf(buf, 1024, "%s/%s", model_dirname, tree->model_fname);
01068     if ((err = write_nb_gauss_model(tree->model, buf)))
01069     {
01070         return err;
01071     }
01072 
01073     for (i = 0; i < tree->num_groups; i++)
01074     {
01075         if (tree->subtrees[ i ])
01076         {
01077             if ((err = write_nb_gauss_model_tree(tree->subtrees[ i ], 
01078                             model_dirname)))
01079             {
01080                 snprintf(str, slen, "%s: %s", tree->model_fname, err->msg);
01081                 return JWSC_EIO(str);
01082             }
01083         }
01084     }
01085 
01086     return NULL;
01087 }
01088 
01089 
01091 void free_nb_gauss_model_tree(NB_gauss_model_tree* tree)
01092 {
01093     uint32_t i;
01094 
01095     if (!tree)
01096         return;
01097 
01098     for (i = 0; i < tree->num_groups; i++)
01099     {
01100         free_nb_gauss_model_tree(tree->subtrees[ i ]);
01101         free_vector_u32(tree->altlabels[ i ]);
01102     }
01103 
01104     free(tree->subtrees);
01105     free(tree->altlabels);
01106     free_vector_u32(tree->labels);
01107     free_vector_d(tree->priors);
01108     free_nb_gauss_model(tree->model);
01109     free(tree->model_fname);
01110     free(tree);
01111 }