Haplo Prediction
predict haplogroups
nb_freq.c
Go to the documentation of this file.
00001 /*
00002  * This work is licensed under a Creative Commons 
00003  * Attribution-Noncommercial-Share Alike 3.0 United States License.
00004  * 
00005  *    http://creativecommons.org/licenses/by-nc-sa/3.0/us/
00006  * 
00007  * You are free:
00008  * 
00009  *    to Share - to copy, distribute, display, and perform the work
00010  *    to Remix - to make derivative works
00011  * 
00012  * Under the following conditions:
00013  * 
00014  *    Attribution. You must attribute the work in the manner specified by the
00015  *    author or licensor (but not in any way that suggests that they endorse you
00016  *    or your use of the work).
00017  * 
00018  *    Noncommercial. You may not use this work for commercial purposes.
00019  * 
00020  *    Share Alike. If you alter, transform, or build upon this work, you may
00021  *    distribute the resulting work only under the same or similar license to
00022  *    this one.
00023  * 
00024  * For any reuse or distribution, you must make clear to others the license
00025  * terms of this work. The best way to do this is by including this header.
00026  * 
00027  * Any of the above conditions can be waived if you get permission from the
00028  * copyright holder.
00029  * 
00030  * Apart from the remix rights granted under this license, nothing in this
00031  * license impairs or restricts the author's moral rights.
00032  */
00033 
00034 
00046 #include <config.h>
00047 
00048 #include <stdlib.h>
00049 #include <stdio.h>
00050 #include <inttypes.h>
00051 #include <assert.h>
00052 #include <string.h>
00053 #include <errno.h>
00054 #include <math.h>
00055 
00056 #include <libxml/tree.h>
00057 #include <libxml/parser.h>
00058 #include <libxml/valid.h>
00059 
00060 #ifdef HAPLO_HAVE_DMALLOC
00061 #include <dmalloc.h>
00062 #endif
00063 
00064 #include <jwsc/base/error.h>
00065 #include <jwsc/base/limits.h>
00066 #include <jwsc/base/file_io.h>
00067 #include <jwsc/vector/vector.h>
00068 #include <jwsc/vector/vector_io.h>
00069 #include <jwsc/vector/vector_math.h>
00070 #include <jwsc/matrix/matrix.h>
00071 #include <jwsc/matrix/matrix_io.h>
00072 
00073 #include "xml.h"
00074 #include "haplo_groups.h"
00075 #include "nb_freq.h"
00076 
00077 
00089 void train_nb_freq_model
00090 (
00091     NB_freq_model**   model_out,
00092     const Vector_u32* labels, 
00093     const Matrix_i32* markers,
00094     const Vector_d*   label_priors
00095 )
00096 {
00097     uint32_t  num_labels;
00098     uint32_t  num_samples;
00099     uint32_t  sample;
00100     uint32_t  marker;
00101     uint32_t  label;
00102     uint32_t  i;
00103     uint32_t  num_markers;
00104     uint32_t  num_marker_vals;
00105     uint32_t  min_marker_val;
00106     uint32_t  max_marker_val;
00107     int32_t   marker_val;
00108     double    sum;
00109 
00110     Matrix_d* marker_given_label;
00111     Vector_d* priors;
00112 
00113     assert(markers->num_rows == labels->num_elts);
00114 
00115     num_labels = get_num_haplo_groups();
00116 
00117     if (*model_out != NULL)
00118     {
00119         marker_given_label = (*model_out)->marker_given_label;
00120         priors = (*model_out)->priors;
00121     }
00122     else
00123     {
00124         marker_given_label = NULL;
00125         priors = NULL;
00126     }
00127 
00128     num_markers = markers->num_cols;
00129     num_samples = markers->num_rows;
00130 
00131     max_marker_val = 1;
00132     min_marker_val = 1;
00133 
00134     for (sample = 0; sample < num_samples; sample++)
00135     {
00136         for (marker = 0; marker < num_markers; marker++)
00137         {
00138             if (markers->elts[ sample ][ marker ] > (int32_t)min_marker_val)
00139             {
00140                 if (markers->elts[ sample ][ marker ] > (int32_t)max_marker_val)
00141                 {
00142                     max_marker_val = markers->elts[ sample ][ marker ];
00143                 }
00144             }
00145         }
00146     }
00147 
00148     num_marker_vals = max_marker_val - min_marker_val + 1;
00149 
00150     create_zero_matrix_d(&marker_given_label, num_labels, 
00151             num_marker_vals*num_markers);
00152 
00153     for (sample = 0; sample < num_samples; sample++)
00154     {
00155         label = labels->elts[ sample ];
00156 
00157         for (marker = 0; marker < num_markers; marker++)
00158         {
00159             marker_val = markers->elts[ sample ][ marker ];
00160 
00161             /* Ignore non-positive marker values. */
00162             if (marker_val > min_marker_val)
00163             {
00164                 i = num_marker_vals*marker + 
00165                     ((uint32_t)marker_val - min_marker_val);
00166 
00167                 marker_given_label->elts[ label ][ i ]++;
00168             }
00169         }
00170     }
00171 
00172     /* If a label doesn't have any representation in the data, leave the
00173      * entire row of marker_given_label zero for sparse reading and writing.
00174      * Otherwise, if a label does appear, set all it's "zero" counts to a
00175      * small value so that there is no zero probability, just a small value. */
00176     for (label = 0; label < num_labels; label++)
00177     {
00178         sum = 0;
00179 
00180         for (i = 0; i < marker_given_label->num_cols; i++)
00181         {
00182             sum += marker_given_label->elts[ label ][ i ];
00183         }
00184 
00185         if (sum)
00186         {
00187             for (i = 0; i < marker_given_label->num_cols; i++)
00188             {
00189                 if (marker_given_label->elts[ label ][ i ] == 0)
00190                 {
00191                     marker_given_label->elts[ label ][ i ] = 0.001;
00192                 }
00193             }
00194         }
00195     }
00196 
00197     for (label = 0; label < num_labels; label++)
00198     {
00199         for (marker = 0; marker < num_markers; marker++)
00200         {
00201             sum = 0;
00202 
00203             for (marker_val = min_marker_val; marker_val < num_marker_vals; 
00204                     marker_val++)
00205             {
00206                 i = num_marker_vals*marker + 
00207                     ((uint32_t)marker_val - min_marker_val); 
00208 
00209                 sum += marker_given_label->elts[ label ][ i ];
00210             }
00211 
00212             if (sum == 0)
00213                 break;
00214 
00215             sum = 1.0 / sum;
00216 
00217             for (marker_val = min_marker_val; marker_val < num_marker_vals; 
00218                     marker_val++)
00219             {
00220                 i = num_marker_vals*marker + 
00221                     ((uint32_t)marker_val - min_marker_val); 
00222 
00223                 marker_given_label->elts[ label ][ i ] *= sum;
00224             }
00225         }
00226     }
00227 
00228     if (label_priors == NULL)
00229     {
00230         create_zero_vector_d(&priors, num_labels);
00231 
00232         for (sample = 0; sample < num_samples; sample++)
00233         {
00234             label = labels->elts[ sample ];
00235             priors->elts[ label ]++;
00236         }
00237         normalize_vector_sum_d(&priors, priors);
00238     }
00239     else
00240     {
00241         normalize_vector_sum_d(&priors, label_priors);
00242     }
00243 
00244 
00245     if (*model_out == NULL)
00246     {
00247         assert(*model_out = malloc(sizeof(NB_freq_model)));
00248     }
00249     (*model_out)->num_markers        = num_markers;
00250     (*model_out)->min_marker_val     = min_marker_val;
00251     (*model_out)->num_marker_vals    = num_marker_vals;
00252     (*model_out)->marker_given_label = marker_given_label;
00253     (*model_out)->priors             = priors;
00254 }
00255 
00256 
00268 Error* predict_label_with_nb_freq_model
00269 (
00270     uint32_t*            label_out,
00271     double*              confidence_out,
00272     const Vector_i32*    markers,
00273     const NB_freq_model* model,
00274     uint32_t             order
00275 )
00276 {
00277     uint32_t num_markers;
00278     uint32_t marker;
00279     uint32_t num_labels;
00280     uint32_t label;
00281     uint32_t best_label;
00282     uint32_t i;
00283     int32_t  marker_val;
00284 
00285     double  map;
00286     double  likelihood;
00287     double  prior;
00288     double  posterior;
00289     double  posterior_sum;
00290 
00291     Vector_u32* best_labels = NULL;
00292     Vector_d* best_labels_conf = NULL;
00293 
00294     num_labels = get_num_haplo_groups();
00295     num_markers = markers->num_elts;
00296 
00297     map = 0;
00298     posterior_sum = 0;
00299 
00300     create_init_vector_u32(&best_labels, num_labels, 0);
00301     create_init_vector_d(&best_labels_conf, num_labels, 0);
00302 
00303     for (label = 0; label < num_labels; label++)
00304     {
00305         likelihood = 1.0;
00306 
00307         for (marker = 0; marker < num_markers; marker++)
00308         {
00309             marker_val = markers->elts[ marker ];
00310 
00311             /* Ignore non-positive or invalid marker values. */
00312             if (marker_val > 0 && marker_val <= 
00313                     (model->num_marker_vals + model->min_marker_val))
00314             {
00315                 i = model->num_marker_vals*marker + 
00316                     ((uint32_t)marker_val - model->min_marker_val);
00317 
00318                 likelihood *= model->marker_given_label->elts[ label ][ i ];
00319             }
00320         }
00321 
00322         prior = model->priors->elts[ label ];
00323 
00324         assert(finite(posterior = likelihood * prior));
00325         posterior_sum += posterior;
00326 
00327         if (posterior > map)
00328         {
00329             map = posterior;
00330             best_label = label;
00331             for (i = num_labels-1; i > 0; i--)
00332             {
00333                 best_labels->elts[ i ] = best_labels->elts[ i - 1];
00334                 best_labels_conf->elts[ i ] = best_labels_conf->elts[ i - 1];
00335             }
00336             best_labels->elts[ 0 ] = best_label;
00337             best_labels_conf->elts[ 0 ] = map;
00338         }
00339     }
00340 
00341     assert(order < best_labels->num_elts);
00342 
00343     *label_out = best_labels->elts[ order ];
00344 
00345     if (posterior_sum > 0)
00346     {
00347         assert(finite(*confidence_out = best_labels_conf->elts[order] / 
00348             (double)posterior_sum));
00349     }
00350     else
00351     {
00352         *confidence_out = 0;
00353     }
00354 
00355     free_vector_u32(best_labels);
00356     free_vector_d(best_labels_conf);
00357 
00358     return NULL;
00359 }
00360 
00361 
00375 Error* predict_labels_with_nb_freq_model
00376 (
00377     Vector_u32**         labels_out,
00378     Vector_d**           confidence_out,
00379     const Matrix_i32*    markers,
00380     const NB_freq_model* model
00381 )
00382 {
00383     uint32_t num_samples;
00384     uint32_t sample;
00385     uint32_t num_markers;
00386     uint32_t marker;
00387 
00388     Vector_u32* labels;
00389     Vector_d*   confidence;
00390     Vector_i32* markers_v;
00391     Error*      e;
00392 
00393     num_samples = markers->num_rows;
00394     num_markers = markers->num_cols;
00395 
00396     create_vector_u32(labels_out, num_samples);
00397     labels = *labels_out;
00398 
00399     create_vector_d(confidence_out, num_samples);
00400     confidence = *confidence_out;
00401 
00402     markers_v = NULL;
00403     create_vector_i32(&markers_v, num_markers);
00404 
00405     for (sample = 0; sample < num_samples; sample++)
00406     {
00407         for (marker = 0; marker < num_markers; marker++)
00408         {
00409             markers_v->elts[ marker ] = markers->elts[ sample ][marker];
00410         }
00411 
00412         if ((e = predict_label_with_nb_freq_model(&(labels->elts[ sample ]), 
00413                 &(confidence->elts[ sample ]), markers_v, model, 0)) != NULL)
00414         {
00415             return e;
00416         }
00417     }
00418 
00419     free_vector_i32(markers_v);
00420 
00421     return NULL;
00422 }
00423 
00424 
00433 Error* read_nb_freq_model(NB_freq_model** model_out, const char* fname)
00434 {
00435     FILE*    fp;
00436     uint32_t label;
00437     uint32_t num_labels;
00438     uint32_t num_cols;
00439     uint32_t s;
00440 
00441     Vector_u32*    sparse = NULL;
00442     Matrix_d*      sparse_marker_given_label = NULL;
00443     NB_freq_model* model;
00444     Error*         e;
00445     int            slen = 256;
00446     char           str[slen];
00447 
00448     if ((fp = fopen(fname, "r")) == NULL)
00449     {
00450         snprintf(str, slen, "%s: %s", fname, strerror(errno));
00451         return JWSC_EARG(str);
00452     }
00453 
00454     if (*model_out == NULL)
00455     {
00456         assert((*model_out = malloc(sizeof(NB_freq_model))) != NULL);
00457         model = *model_out;
00458         model->marker_given_label = NULL;
00459         model->priors = NULL;
00460     }
00461 
00462     if (fscanf(fp, "%u %u %u\n", &(model->num_markers), 
00463                     &(model->min_marker_val), &(model->num_marker_vals)) != 3)
00464     {
00465         snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file");
00466         return JWSC_EARG(str);
00467     }
00468 
00469     if ((e = read_vector_with_header_fp_u32(&(sparse), fp)) != NULL)
00470     {
00471         snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file");
00472         return JWSC_EARG(str);
00473     }
00474 
00475     if ((e = read_matrix_with_header_fp_d(&(sparse_marker_given_label), fp))
00476             != NULL)
00477     {
00478         snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file");
00479         return JWSC_EARG(str);
00480     }
00481 
00482     if ((e = read_vector_with_header_fp_d(&(model->priors), fp)) != NULL)
00483     {
00484         snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file");
00485         return JWSC_EARG(str);
00486     }
00487 
00488     if (fclose(fp) != 0)
00489     {
00490         snprintf(str, slen, "%s: %s", fname, strerror(errno));
00491         return JWSC_EARG(str);
00492     }
00493 
00494     num_labels = sparse->num_elts;
00495     num_cols   = sparse_marker_given_label->num_cols;
00496 
00497     create_zero_matrix_d(&(model->marker_given_label), num_labels, num_cols);
00498 
00499     s = 0;
00500     for (label = 0; label < num_labels; label++)
00501     {
00502         if (sparse->elts[ label ])
00503         {
00504             copy_matrix_block_into_matrix_d(model->marker_given_label,
00505                     label, 0, sparse_marker_given_label, s++, 0, 1, num_cols);
00506         }
00507     }
00508 
00509     free_vector_u32(sparse);
00510     free_matrix_d(sparse_marker_given_label);
00511 
00512     return NULL;
00513 }
00514 
00515 
00522 Error* write_nb_freq_model(NB_freq_model* model, const char* fname)
00523 {
00524     FILE* fp;
00525     uint32_t label;
00526     uint32_t num_labels;
00527     uint32_t i;
00528     uint32_t num_cols;
00529     uint32_t s;
00530     int  slen = 256;
00531     char str[slen];
00532 
00533     Vector_u32* sparse = NULL;
00534     Matrix_d*   sparse_marker_given_label = NULL;
00535 
00536     num_labels = model->marker_given_label->num_rows;
00537     num_cols   = model->marker_given_label->num_cols;
00538 
00539     create_zero_vector_u32(&sparse, num_labels);
00540 
00541     s = 0;
00542     for (label = 0; label < num_labels; label++)
00543     {
00544         for (i = 0; i < num_cols; i++)
00545         {
00546             if (model->marker_given_label->elts[ label ][ i ] > 0)
00547             {
00548                 sparse->elts[ label ] = 1;
00549                 s++;
00550                 break;
00551             }
00552         }
00553     }
00554 
00555     create_matrix_d(&sparse_marker_given_label, s, num_cols);
00556 
00557     s = 0;
00558     for (label = 0; label < num_labels; label++)
00559     {
00560         if (sparse->elts[ label ])
00561         {
00562             copy_matrix_block_into_matrix_d(sparse_marker_given_label,
00563                     s++, 0, model->marker_given_label, label, 0, 1, num_cols);
00564         }
00565     }
00566 
00567     if ((fp = fopen(fname, "w")) == NULL)
00568     {
00569         snprintf(str, slen, "%s: %s", fname, strerror(errno));
00570         return JWSC_EARG(str);
00571     }
00572 
00573     fprintf(fp, "%u %u %u\n", model->num_markers, model->min_marker_val,
00574             model->num_marker_vals);
00575 
00576     write_vector_with_header_fp_u32(sparse, fp);
00577     write_matrix_with_header_fp_d(sparse_marker_given_label, fp);
00578     write_vector_with_header_fp_d(model->priors, fp);
00579 
00580     free_vector_u32(sparse);
00581     free_matrix_d(sparse_marker_given_label);
00582 
00583     if (fclose(fp) != 0)
00584     {
00585         snprintf(str, slen, "%s: %s", fname, strerror(errno));
00586         return JWSC_EARG(str);
00587     }
00588 
00589     return NULL;
00590 }
00591 
00592 
00594 void free_nb_freq_model(NB_freq_model* model)
00595 {
00596     if (!model)
00597         return;
00598 
00599     free_matrix_d(model->marker_given_label);
00600     free_vector_d(model->priors);
00601     free(model);
00602 }
00603 
00604 
00606 static Error* read_nb_freq_xml_doc
00607 (
00608     xmlDoc**    xml_doc_out,
00609     const char* xml_fname,
00610     const char* dtd_fname
00611 )
00612 {
00613     xmlParserCtxt* xml_parse_ctxt;
00614     xmlValidCtxt*  xml_valid_ctxt;
00615     xmlDtd*        xml_dtd;
00616     int            slen = 256;
00617     char           str[slen];
00618 
00619     assert(xml_parse_ctxt = xmlNewParserCtxt());
00620 
00621     if (!(*xml_doc_out = xmlCtxtReadFile(xml_parse_ctxt, xml_fname, NULL, 0)))
00622     {
00623         snprintf(str, slen, "%s: %s", xml_fname, "Could not parse file");
00624         return JWSC_EARG(str);
00625     } 
00626 
00627     xmlFreeParserCtxt(xml_parse_ctxt);
00628 
00629     if (dtd_fname)
00630     {
00631         assert(xml_valid_ctxt = xmlNewValidCtxt());
00632 
00633         if (!(xml_dtd = xmlParseDTD(NULL, (xmlChar*)dtd_fname)))
00634         {
00635             snprintf(str, slen, "%s: %s", dtd_fname, "Could not parse DTD");
00636             return JWSC_EARG(str);
00637         }
00638 
00639         if (!xmlValidateDtd(xml_valid_ctxt, *xml_doc_out, xml_dtd))
00640         {
00641             snprintf(str, slen, "%s: %s", xml_fname, "XML file not valid");
00642             return JWSC_EARG(str);
00643         }
00644 
00645         xmlFreeValidCtxt(xml_valid_ctxt);
00646         xmlFreeDtd(xml_dtd);
00647     }
00648 
00649     return NULL;
00650 }
00651 
00652 
00659 static Error* create_nb_freq_model_tree_from_xml_node
00660 (
00661      NB_freq_model_tree**      tree_out, 
00662      const NB_freq_model_tree* parent,
00663      uint32_t                  parent_label,
00664      xmlNode*                  xml_node
00665 )
00666 {
00667     NB_freq_model_tree* tree;
00668     uint32_t i, j;
00669     uint32_t label;
00670     uint32_t altlabel;
00671     uint32_t num_altlabels;
00672     double   p;
00673     xmlAttr* attr;
00674     xmlNode* it;
00675     xmlNode* itt;
00676     Error*   err;
00677     const char* fname;
00678 
00679     if (*tree_out)
00680     {
00681         free_nb_freq_model_tree(*tree_out);
00682     }
00683 
00684     assert(*tree_out = malloc(sizeof(NB_freq_model_tree)));
00685     tree = *tree_out;
00686 
00687     tree->parent       = parent;
00688     tree->parent_label = parent_label;
00689     tree->num_groups   = 0;
00690     tree->subtrees     = NULL;
00691     tree->labels       = NULL;
00692     tree->altlabels    = NULL;
00693     tree->priors       = NULL;
00694     tree->model        = NULL;
00695     tree->model_fname  = NULL;
00696 
00697     assert(XMLStrEqual(xml_node->name, "model"));
00698 
00699     for (it = xml_node->children; it; it = it->next)
00700     {
00701         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "group"))
00702         {
00703             tree->num_groups++;
00704         }
00705     }
00706 
00707     assert(tree->subtrees = calloc(tree->num_groups, sizeof(void*)));
00708     assert(tree->altlabels = calloc(tree->num_groups, sizeof(void*)));
00709     create_vector_u32(&(tree->labels), tree->num_groups);
00710 
00711     for (attr = xml_node->properties; attr; attr = attr->next)
00712     {
00713         if (XMLStrEqual(attr->name, "priors") && 
00714                 XMLStrEqual(attr->children->content, "true"))
00715         {
00716             create_init_vector_d(&(tree->priors), get_num_haplo_groups(), 1.0);
00717             break;
00718         }
00719     }
00720 
00721     i = 0;
00722     for (it = xml_node->children; it; it = it->next)
00723     {
00724         j = 0;
00725         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "file"))
00726         {
00727             fname = (const char*) it->children->content;
00728             tree->model_fname = malloc((strlen(fname)+1) * sizeof(char));
00729             strcpy(tree->model_fname, fname);
00730         }
00731         else if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "group"))
00732         {
00733             num_altlabels = 0;
00734             for (itt = it->children; itt; itt = itt->next)
00735             {
00736                 if (itt->type == XML_ELEMENT_NODE &&
00737                         XMLStrEqual(itt->name, "altlabel"))
00738                 {
00739                     num_altlabels++;
00740                 }
00741             }
00742             if (num_altlabels)
00743             {
00744                 create_vector_u32(&(tree->altlabels[ i ]), num_altlabels);
00745             }
00746 
00747             for (itt = it->children; itt; itt = itt->next)
00748             {
00749                 if (itt->type == XML_ELEMENT_NODE &&
00750                         XMLStrEqual(itt->name, "label"))
00751                 {
00752                     if ((err = lookup_haplo_group_index_from_label(&label, 
00753                                     (const char*) itt->children->content)))
00754                     {
00755                         return err;
00756                     }
00757                     tree->labels->elts[ i ] = label;
00758                 }
00759                 else if (itt->type == XML_ELEMENT_NODE &&
00760                         XMLStrEqual(itt->name, "altlabel"))
00761                 {
00762                     if ((err = lookup_haplo_group_index_from_label(&altlabel, 
00763                                     (const char*) itt->children->content)))
00764                     {
00765                         return err;
00766                     }
00767                     tree->altlabels[ i ]->elts[ j++ ] = altlabel;
00768                 }
00769                 else if (tree->priors && itt->type == XML_ELEMENT_NODE &&
00770                         XMLStrEqual(itt->name, "prior"))
00771                 {
00772                     if (sscanf((char*)itt->children->content, "%lf", &p) != 1 ||
00773                             p < 0 || p > 1)
00774                     {
00775                         return JWSC_EARG("Invalid prior");
00776                     }
00777                     tree->priors->elts[ i ] = p;
00778                 }
00779                 else if (itt->type == XML_ELEMENT_NODE &&
00780                         XMLStrEqual(itt->name, "model"))
00781                 {
00782                     if ((err = create_nb_freq_model_tree_from_xml_node(
00783                                  &(tree->subtrees[ i ]), tree, label, itt)))
00784                     {
00785                         return err;
00786                     }
00787                 }
00788             }
00789             i++;
00790         }
00791     }
00792     assert(i == tree->num_groups);
00793 
00794     return NULL;
00795 }
00796 
00797 
00802 static Error* create_model_training_data
00803 (
00804     Vector_u32**      train_labels_out,
00805     Matrix_i32**      train_markers_out,
00806     const Vector_u32* data_labels,
00807     const Matrix_i32* data_markers,
00808     const Vector_u32* model_labels,
00809     Vector_u32*const* model_altlabels
00810 )
00811 {
00812     uint8_t     b;
00813     uint32_t    i, j, k;
00814     uint32_t    n;
00815     Vector_u32* train_labels = NULL;
00816     Matrix_i32* train_markers = NULL;
00817 
00818     copy_vector_u32(&train_labels, data_labels);
00819     copy_matrix_i32(&train_markers, data_markers);
00820 
00821     n = 0;
00822     for (i = 0; i < data_labels->num_elts; i++)
00823     {
00824         for (j = 0; j < model_labels->num_elts; j++)
00825         {
00826             if (is_ancestor(data_labels->elts[ i ], model_labels->elts[ j ]))
00827             {
00828                 train_labels->elts[ n ] = model_labels->elts[ j ];
00829                 copy_matrix_block_into_matrix_i32(train_markers, n, 0,
00830                         data_markers, i, 0, 1, data_markers->num_cols);
00831                 n++;
00832                 break;
00833             }
00834             else if (model_altlabels[ j ])
00835             {
00836                 for (k = 0; k < model_altlabels[ j ]->num_elts; k++)
00837                 {
00838                     if ((b = is_ancestor(data_labels->elts[ i ], 
00839                                 model_altlabels[ j ]->elts[ k ])))
00840                     {
00841                         train_labels->elts[ n ] = model_labels->elts[ j ];
00842                         copy_matrix_block_into_matrix_i32(train_markers, n, 0,
00843                                 data_markers, i, 0, 1, data_markers->num_cols);
00844                         n++;
00845                         break;
00846                     }
00847                 }
00848                 if (b)
00849                 {
00850                     break;
00851                 }
00852             }
00853         }
00854     }
00855 
00856     if (!n)
00857     {
00858         return JWSC_EARG("No data for model");
00859     }
00860 
00861     copy_vector_section_u32(train_labels_out, train_labels, 0, n);
00862     copy_matrix_block_i32(train_markers_out, train_markers, 0, 0, n, 
00863             train_markers->num_cols);
00864 
00865     free_vector_u32(train_labels);
00866     free_matrix_i32(train_markers);
00867 
00868     return NULL;
00869 }
00870 
00871 
00873 static Error* train_nb_freq_model_node
00874 (
00875     NB_freq_model_node* node,
00876     const Vector_u32*   labels, 
00877     const Matrix_i32*   markers
00878 )
00879 {
00880     uint32_t i;
00881     Vector_u32* node_labels  = NULL;
00882     Matrix_i32* node_markers = NULL;
00883     Error* err = NULL;
00884     int  slen = 256;
00885     char str[slen];
00886 
00887     if ((err = create_model_training_data(&node_labels, &node_markers, labels,
00888                     markers, node->labels, node->altlabels)))
00889     {
00890         snprintf(str, slen, "%s: %s", node->model_fname, err->msg);
00891         return JWSC_EARG(str);
00892     }
00893 
00894     train_nb_freq_model(&(node->model), node_labels, node_markers, 
00895             node->priors);
00896 
00897     for (i = 0; i < node->num_groups; i++)
00898     {
00899         if (node->subtrees[ i ])
00900         {
00901             if ((err = train_nb_freq_model_node(node->subtrees[ i ], labels,
00902                         markers)))
00903             {
00904                 return err;
00905             }
00906         }
00907     }
00908 
00909     free_vector_u32(node_labels);
00910     free_matrix_i32(node_markers);
00911 
00912     return NULL;
00913 }
00914 
00915 
00923 Error* train_nb_freq_model_tree
00924 (
00925     NB_freq_model_tree** tree_out,
00926     const Vector_u32*    labels, 
00927     const Matrix_i32*    markers,
00928     const char*          tree_xml_fname,
00929     const char*          tree_dtd_fname
00930 )
00931 {
00932     xmlDoc*  xml_doc;
00933     xmlNode* xml_node;
00934     xmlNode* it;
00935     Error*   err;
00936     int      slen = 256;
00937     char     str[slen];
00938 
00939     assert(tree_xml_fname);
00940 
00941     if ((err = read_nb_freq_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname)))
00942     {
00943         return err;
00944     }
00945 
00946     for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next)
00947     {
00948         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model"))
00949         {
00950             xml_node = it;
00951             break;
00952         }
00953     }
00954 
00955     if ((err = create_nb_freq_model_tree_from_xml_node(tree_out, NULL, 0, 
00956                     xml_node)))
00957     {
00958         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
00959         return JWSC_EARG(str);
00960     }
00961 
00962     if ((err = train_nb_freq_model_node(*tree_out, labels, markers)))
00963     {
00964         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
00965         return JWSC_EARG(str);
00966     }
00967 
00968     return NULL;
00969 }
00970 
00971 
00973 static Error* recursively_predict_label_in_model_tree
00974 (
00975     uint32_t*                 label_out,
00976     double*                   confidence_out,
00977     const NB_freq_model_tree* tree,
00978     const Vector_i32*         markers_v,
00979     uint32_t                  order
00980 )
00981 {
00982     uint32_t subtree_label;
00983     uint32_t i;
00984     double   confidence;
00985     Error*   err;
00986 
00987     if ((err = predict_label_with_nb_freq_model(label_out, &confidence, 
00988                     markers_v, tree->model, order)) != NULL)
00989     {
00990         return err;
00991     }
00992     *confidence_out *= confidence;
00993 
00994     for (i = 0; i < tree->num_groups; i++)
00995     {
00996         if (tree->subtrees[ i ])
00997         {
00998             subtree_label = tree->subtrees[ i ]->parent_label;
00999 
01000             if (subtree_label == *label_out)
01001             {
01002                 if ((err = recursively_predict_label_in_model_tree(label_out,
01003                                 confidence_out, tree->subtrees[ i ], markers_v,
01004                                 order)) != NULL)
01005                 {
01006                     return err;
01007                 }
01008             }
01009         }
01010     }
01011 
01012     return NULL;
01013 }
01014 
01015 
01031 Error* predict_labels_with_nb_freq_model_tree
01032 (
01033     Vector_u32**              labels_out,
01034     Vector_d**                confidence_out,
01035     const Matrix_i32*         markers,
01036     const NB_freq_model_tree* tree,
01037     uint32_t                  order
01038 )
01039 {
01040     uint32_t s, num_samples;
01041     uint32_t m, num_markers;
01042 
01043     Vector_u32* labels;
01044     Vector_d*   confidence;
01045     Vector_i32* markers_v;
01046     Error*      err;
01047 
01048     num_samples = markers->num_rows;
01049     num_markers = markers->num_cols;
01050 
01051     create_vector_u32(labels_out, num_samples);
01052     labels = *labels_out;
01053 
01054     create_init_vector_d(confidence_out, num_samples, 1.0);
01055     confidence = *confidence_out;
01056 
01057     markers_v = NULL;
01058     create_vector_i32(&markers_v, num_markers);
01059 
01060     for (s = 0; s < num_samples; s++)
01061     {
01062         for (m = 0; m < num_markers; m++)
01063         {
01064             markers_v->elts[ m ]= markers->elts[ s ][ m ];
01065         }
01066 
01067         if ((err = recursively_predict_label_in_model_tree(
01068                         &(labels->elts[ s ]), 
01069                         &(confidence->elts[ s ]), tree, markers_v, order)) 
01070                 != NULL)
01071         {
01072             return err;
01073         }
01074     }
01075 
01076     free_vector_i32(markers_v);
01077 
01078     return NULL;
01079 }
01080 
01081 
01083 static Error* read_nb_freq_model_node
01084 (
01085      NB_freq_model_node* node, 
01086      const char*         model_dirname
01087 )
01088 {
01089     uint32_t    i;
01090     char        buf[1024] = {0};
01091     Error*      err;
01092 
01093     snprintf(buf, 1024, "%s/%s", model_dirname, node->model_fname);
01094     if ((err = read_nb_freq_model(&(node->model), buf)))
01095     {
01096         return err;
01097     }
01098 
01099     for (i = 0; i < node->num_groups; i++)
01100     {
01101         if (node->subtrees[ i ])
01102         {
01103             if ((err = read_nb_freq_model_node(node->subtrees[ i ], 
01104                             model_dirname)))
01105             {
01106                 return err;
01107             }
01108         }
01109     }
01110 
01111     return NULL;
01112 }
01113 
01114 
01121 Error* read_nb_freq_model_tree
01122 (
01123     NB_freq_model_tree** tree_out,
01124     const char*          tree_xml_fname,
01125     const char*          tree_dtd_fname,
01126     const char*          model_dirname
01127 )
01128 {
01129     xmlDoc*  xml_doc;
01130     xmlNode* xml_node;
01131     xmlNode* it;
01132     Error*   err;
01133     int      slen = 256;
01134     char     str[slen];
01135 
01136     assert(tree_xml_fname && model_dirname);
01137 
01138     if ((err = read_nb_freq_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname)))
01139     {
01140         return err;
01141     }
01142 
01143     for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next)
01144     {
01145         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model"))
01146         {
01147             xml_node = it;
01148             break;
01149         }
01150     }
01151 
01152     if ((err = create_nb_freq_model_tree_from_xml_node(tree_out, NULL, 0,
01153                     xml_node)))
01154     {
01155         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
01156         return JWSC_EARG(str);
01157     }
01158 
01159     if ((err = read_nb_freq_model_node(*tree_out, model_dirname)))
01160     {
01161         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
01162         return JWSC_EARG(str);
01163     }
01164 
01165     return NULL;
01166 }
01167 
01168 
01173 Error* write_nb_freq_model_tree
01174 (
01175     const NB_freq_model_tree* tree,
01176     const char*               model_dirname
01177 )
01178 {
01179     uint32_t i;
01180     char     buf[1024] = {0};
01181     Error*   err;
01182     int      slen = 256;
01183     char     str[slen];
01184 
01185     snprintf(buf, 1024, "%s/%s", model_dirname, tree->model_fname);
01186     if ((err = write_nb_freq_model(tree->model, buf)))
01187     {
01188         return err;
01189     }
01190 
01191     for (i = 0; i < tree->num_groups; i++)
01192     {
01193         if (tree->subtrees[ i ])
01194         {
01195             if ((err = write_nb_freq_model_tree(tree->subtrees[ i ], 
01196                             model_dirname)))
01197             {
01198                 snprintf(str, slen, "%s: %s", tree->model_fname, err->msg);
01199                 return JWSC_EARG(str);
01200             }
01201         }
01202     }
01203 
01204     return NULL;
01205 }
01206 
01207 
01209 void free_nb_freq_model_tree(NB_freq_model_tree* tree)
01210 {
01211     uint32_t i;
01212 
01213     if (!tree)
01214         return;
01215 
01216     for (i = 0; i < tree->num_groups; i++)
01217     {
01218         free_nb_freq_model_tree(tree->subtrees[ i ]);
01219         free_vector_u32(tree->altlabels[ i ]);
01220     }
01221 
01222     free(tree->subtrees);
01223     free(tree->altlabels);
01224     free_vector_u32(tree->labels);
01225     free_vector_d(tree->priors);
01226     free_nb_freq_model(tree->model);
01227     free(tree->model_fname);
01228     free(tree);
01229 }