Haplo Prediction
predict haplogroups
nb_gmm.c
Go to the documentation of this file.
00001 /*
00002  * This work is licensed under a Creative Commons 
00003  * Attribution-Noncommercial-Share Alike 3.0 United States License.
00004  * 
00005  *    http://creativecommons.org/licenses/by-nc-sa/3.0/us/
00006  * 
00007  * You are free:
00008  * 
00009  *    to Share - to copy, distribute, display, and perform the work
00010  *    to Remix - to make derivative works
00011  * 
00012  * Under the following conditions:
00013  * 
00014  *    Attribution. You must attribute the work in the manner specified by the
00015  *    author or licensor (but not in any way that suggests that they endorse you
00016  *    or your use of the work).
00017  * 
00018  *    Noncommercial. You may not use this work for commercial purposes.
00019  * 
00020  *    Share Alike. If you alter, transform, or build upon this work, you may
00021  *    distribute the resulting work only under the same or similar license to
00022  *    this one.
00023  * 
00024  * For any reuse or distribution, you must make clear to others the license
00025  * terms of this work. The best way to do this is by including this header.
00026  * 
00027  * Any of the above conditions can be waived if you get permission from the
00028  * copyright holder.
00029  * 
00030  * Apart from the remix rights granted under this license, nothing in this
00031  * license impairs or restricts the author's moral rights.
00032  */
00033 
00034 
00046 #include <config.h>
00047 
00048 #include <stdlib.h>
00049 #include <stdio.h>
00050 #include <inttypes.h>
00051 #include <assert.h>
00052 #include <string.h>
00053 #include <errno.h>
00054 #include <math.h>
00055 
00056 #include <libxml/tree.h>
00057 #include <libxml/parser.h>
00058 #include <libxml/valid.h>
00059 
00060 #ifdef HAPLO_HAVE_DMALLOC
00061 #include <dmalloc.h>
00062 #endif
00063 
00064 #include <jwsc/base/error.h>
00065 #include <jwsc/base/limits.h>
00066 #include <jwsc/base/file_io.h>
00067 #include <jwsc/vector/vector.h>
00068 #include <jwsc/vector/vector_io.h>
00069 #include <jwsc/vector/vector_math.h>
00070 #include <jwsc/matrix/matrix.h>
00071 #include <jwsc/matrix/matrix_io.h>
00072 #include <jwsc/matblock/matblock.h>
00073 #include <jwsc/matblock/matblock_io.h>
00074 #include <jwsc/stat/gmm.h>
00075 
00076 #include "xml.h"
00077 #include "haplo_groups.h"
00078 #include "nb_gmm.h"
00079 
00080 
00091 void train_nb_gmm_model
00092 (
00093     NB_gmm_model**    model_out,
00094     const Vector_u32* labels, 
00095     const Matrix_i32* markers,
00096     const Vector_d*   label_priors,
00097     uint32_t          num_components
00098 )
00099 {
00100     uint32_t  l, num_labels;
00101     uint32_t  s, num_samples;
00102     uint32_t  m, num_markers;
00103     uint32_t  n, N;
00104 
00105     Matrix_d***   mu;
00106     Matblock_d*** sigma;
00107     Vector_d***   pi;
00108     Vector_d*     priors;
00109     Matrix_d*     markers_d = NULL;
00110 
00111     assert(markers->num_rows == labels->num_elts);
00112 
00113     num_markers = markers->num_cols;
00114     num_samples = markers->num_rows;
00115     num_labels  = get_num_haplo_groups();
00116 
00117     if (*model_out != NULL)
00118     {
00119         free_nb_gmm_model(*model_out);
00120     }
00121 
00122     assert(mu    = malloc(num_labels*sizeof(Matrix_d*)));
00123     assert(sigma = malloc(num_labels*sizeof(Matblock_d*)));
00124     assert(pi    = malloc(num_labels*sizeof(Vector_d*)));
00125     priors = NULL;
00126 
00127     for (l = 0; l < num_labels; l++)
00128     {
00129         mu   [ l ] = NULL;
00130         sigma[ l ] = NULL;
00131         pi   [ l ] = NULL;
00132 
00133         N = 0;
00134         for (s = 0; s < num_samples; s++)
00135         {
00136             if (labels->elts[ s ] == l)
00137             {
00138                 N++;
00139             }
00140         }
00141 
00142         // TODO we should probably handle the case of N=1 by creating a
00143         // distribution with the sample as the mean and a very small diagonal
00144         // covariance matrix.
00145         if (N < num_components*2)
00146             continue;
00147 
00148         assert(mu[ l ]    = malloc(num_markers*sizeof(void*)));
00149         assert(sigma[ l ] = malloc(num_markers*sizeof(void*)));
00150         assert(pi[ l ]    = malloc(num_markers*sizeof(void*)));
00151 
00152         create_matrix_d(&markers_d, N, 1);
00153 
00154         for (m = 0; m < num_markers; m++)
00155         {
00156             n = 0;
00157             for (s = 0; s < num_samples; s++)
00158             {
00159                 if (labels->elts[ s ] == l)
00160                 {
00161                     markers_d->elts[ n ][ 0 ] = markers->elts[ s ][ m ];
00162                     n++;
00163                 }
00164             }
00165 
00166             mu[ l ][ m ]    = NULL;
00167             sigma[ l ][ m ] = NULL;
00168             pi[ l ][ m ]    = NULL;
00169 
00170             train_gmm_d(&(mu[l][m]), &(sigma[l][m]), &(pi[l][m]), NULL, 
00171                     markers_d, num_components);
00172         }
00173     }
00174 
00175     free_matrix_d(markers_d);
00176 
00177     if (label_priors == NULL)
00178     {
00179         create_zero_vector_d(&priors, num_labels);
00180 
00181         for (s = 0; s < num_samples; s++)
00182         {
00183             l = labels->elts[ s ];
00184             priors->elts[ l ]++;
00185         }
00186         normalize_vector_sum_d(&priors, priors);
00187     }
00188     else
00189     {
00190         normalize_vector_sum_d(&priors, label_priors);
00191     }
00192 
00193     assert(*model_out = malloc(sizeof(NB_gmm_model)));
00194     (*model_out)->num_markers    = num_markers;
00195     (*model_out)->num_components = num_components;
00196     (*model_out)->mu             = mu;
00197     (*model_out)->sigma          = sigma;
00198     (*model_out)->pi             = pi;
00199     (*model_out)->priors         = priors;
00200 }
00201 
00202 
00214 Error* predict_label_with_nb_gmm_model
00215 (
00216     uint32_t*           label_out,
00217     double*             confidence_out,
00218     const Vector_i32*   markers,
00219     const NB_gmm_model* model,
00220     uint32_t            order
00221 )
00222 {
00223     uint32_t num_markers;
00224     uint32_t marker;
00225     uint32_t num_labels;
00226     uint32_t label;
00227     uint32_t best_label;
00228     uint32_t i;
00229     double   log_map;
00230     double   log_ll;
00231     double   log_prior;
00232     double   log_posterior;
00233     double   posterior_sum;
00234 
00235     Vector_u32* best_labels     = NULL;
00236     Vector_d*  best_labels_conf = NULL;
00237     Matrix_d*  marker_val       = NULL;
00238 
00239     const Matrix_d*   mu;
00240     const Matblock_d* sigma;
00241     const Vector_d*   pi;
00242 
00243     num_labels = get_num_haplo_groups();
00244     num_markers = markers->num_elts;
00245 
00246     log_map = log(JWSC_MIN_LOG_ARG);
00247     posterior_sum = 0;
00248 
00249     create_init_vector_u32(&best_labels, num_labels, 0);
00250     create_init_vector_d(&best_labels_conf, num_labels, 0);
00251 
00252     create_matrix_d(&marker_val, 1, 1);
00253 
00254     for (label = 0; label < num_labels; label++)
00255     {
00256         if (!model->mu[ label ])
00257             continue;
00258 
00259         log_ll = 0;
00260 
00261         for (marker = 0; marker < num_markers; marker++)
00262         {
00263             marker_val->elts[0][0] = markers->elts[ marker ];
00264 
00265             /* Ignore non-positive marker values. */
00266             if (marker_val > 0)
00267             {
00268                 mu = model->mu[ label ][ marker ];
00269                 sigma = model->sigma[ label ][ marker ];
00270                 pi = model->pi[ label ][ marker ];
00271 
00272                 log_ll += gmm_log_likelihood_d(mu, sigma, pi, marker_val);
00273             }
00274         }
00275 
00276         if (model->priors->elts[ label ] < JWSC_MIN_LOG_ARG)
00277         {
00278             log_prior = log(JWSC_MIN_LOG_ARG);
00279         }
00280         else
00281         {
00282             log_prior = log(model->priors->elts[ label ]);
00283         }
00284 
00285         log_posterior = log_ll + log_prior;
00286         posterior_sum += exp(log_posterior);
00287 
00288         if (log_posterior > log_map)
00289         {
00290             log_map = log_posterior;
00291             best_label = label;
00292             for (i = num_labels-1; i > 0; i--)
00293             {
00294                 best_labels->elts[ i ] = best_labels->elts[ i - 1];
00295                 best_labels_conf->elts[ i ] = best_labels_conf->elts[ i - 1];
00296             }
00297             best_labels->elts[ 0 ] = best_label;
00298             best_labels_conf->elts[ 0 ] = log_map;
00299         }
00300     }
00301 
00302     assert(order < best_labels->num_elts);
00303 
00304     *label_out = best_labels->elts[ order ];
00305 
00306     if (posterior_sum > 0)
00307     {
00308         assert(finite(*confidence_out = exp(best_labels_conf->elts[order]) / 
00309             (double)posterior_sum));
00310     }
00311     else
00312     {
00313         *confidence_out = 0;
00314     }
00315 
00316     free_vector_u32(best_labels);
00317     free_vector_d(best_labels_conf);
00318     free_matrix_d(marker_val);
00319 
00320     return NULL;
00321 }
00322 
00323 
00337 Error* predict_labels_with_nb_gmm_model
00338 (
00339     Vector_u32**        labels_out,
00340     Vector_d**          confidence_out,
00341     const Matrix_i32*   markers,
00342     const NB_gmm_model* model
00343 )
00344 {
00345     uint32_t num_samples;
00346     uint32_t sample;
00347     uint32_t num_markers;
00348     uint32_t marker;
00349 
00350     Vector_u32* labels;
00351     Vector_d*   confidence;
00352     Vector_i32* markers_v;
00353     Error*      e;
00354 
00355     num_samples = markers->num_rows;
00356     num_markers = markers->num_cols;
00357 
00358     create_vector_u32(labels_out, num_samples);
00359     labels = *labels_out;
00360 
00361     create_vector_d(confidence_out, num_samples);
00362     confidence = *confidence_out;
00363 
00364     markers_v = NULL;
00365     create_vector_i32(&markers_v, num_markers);
00366 
00367     for (sample = 0; sample < num_samples; sample++)
00368     {
00369         for (marker = 0; marker < num_markers; marker++)
00370         {
00371             markers_v->elts[ marker ] = markers->elts[ sample ][marker];
00372         }
00373 
00374         if ((e = predict_label_with_nb_gmm_model(&(labels->elts[ sample ]), 
00375                 &(confidence->elts[ sample ]), markers_v, model, 0)) != NULL)
00376         {
00377             return e;
00378         }
00379     }
00380 
00381     free_vector_i32(markers_v);
00382 
00383     return NULL;
00384 }
00385 
00386 
00395 Error* read_nb_gmm_model(NB_gmm_model** model_out, const char* fname)
00396 {
00397     uint32_t      num_labels;
00398     uint32_t      i, j;
00399     FILE*         fp;
00400     NB_gmm_model* model;
00401     Error*        e;
00402     int           slen = 256;
00403     char          str[slen];
00404 
00405     num_labels = get_num_haplo_groups();
00406 
00407     if ((fp = fopen(fname, "r")) == NULL)
00408     {
00409         snprintf(str, slen, "%s: %s", fname, strerror(errno));
00410         return JWSC_EARG(str);
00411     }
00412 
00413     if (*model_out == NULL)
00414     {
00415         assert((*model_out = malloc(sizeof(NB_gmm_model))) != NULL);
00416         model = *model_out;
00417         model->priors = NULL;
00418     }
00419 
00420     if (fscanf(fp, "%u\n%u\n", &(model->num_markers), 
00421                 &(model->num_components)) != 2)
00422     {
00423         snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file");
00424         return JWSC_EARG(str);
00425     }
00426 
00427     assert(model->mu    = malloc(num_labels*sizeof(void*)));
00428     assert(model->sigma = malloc(num_labels*sizeof(void*)));
00429     assert(model->pi    = malloc(num_labels*sizeof(void*)));
00430 
00431     for (i = 0; i < num_labels; i++)
00432     {
00433         if (fscanf(fp, "%u", (uint32_t*)(&(model->mu[ i ]))) != 1)
00434         {
00435             return JWSC_EARG("Improperly formatted model file");
00436         }
00437     }
00438 
00439     for (i = 0; i < num_labels; i++)
00440     {
00441         model->sigma[ i ] = NULL;
00442         model->pi[ i ]    = NULL;
00443 
00444         if (model->mu[ i ])
00445         {
00446             assert(model->mu[i]    = malloc(model->num_markers*sizeof(void*)));
00447             assert(model->sigma[i] = malloc(model->num_markers*sizeof(void*)));
00448             assert(model->pi[i]    = malloc(model->num_markers*sizeof(void*)));
00449 
00450             for (j = 0; j < model->num_markers; j++)
00451             {
00452                 model->mu[ i ][ j ]    = NULL;
00453                 model->sigma[ i ][ j ] = NULL;
00454                 model->pi[ i ][ j ]    = NULL;
00455 
00456                 if ((e = read_matrix_with_header_fp_d(
00457                                 &(model->mu[i][j]), fp)) ||
00458                     (e = read_matblock_with_header_fp_d(
00459                                 &(model->sigma[i][j]), fp)) ||
00460                     (e = read_vector_with_header_fp_d(
00461                                 &(model->pi[i][j]), fp)))
00462                 {
00463                     return JWSC_EARG("Improperly formatted model file");
00464                 }
00465             }
00466         }
00467     }
00468 
00469     if ((e = read_vector_with_header_fp_d(&(model->priors), fp)) != NULL)
00470     {
00471         snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file");
00472         return JWSC_EARG(str);
00473     }
00474 
00475     if (fclose(fp) != 0)
00476     {
00477         snprintf(str, slen, "%s: %s", fname, strerror(errno));
00478         return JWSC_EARG(str);
00479     }
00480 
00481     return NULL;
00482 }
00483 
00484 
00491 Error* write_nb_gmm_model(NB_gmm_model* model, const char* fname)
00492 {
00493     uint32_t num_labels;
00494     uint32_t i, j;
00495     FILE*    fp;
00496     int      slen = 256;
00497     char     str[slen];
00498 
00499     num_labels = get_num_haplo_groups();
00500 
00501     if ((fp = fopen(fname, "w")) == NULL)
00502     {
00503         snprintf(str, slen, "%s: %s", fname, strerror(errno));
00504         return JWSC_EARG(str);
00505     }
00506 
00507     fprintf(fp, "%u\n", model->num_markers);
00508     fprintf(fp, "%u\n", model->num_components);
00509 
00510     for (i = 0; i < num_labels; i++)
00511     {
00512         fprintf(fp, "%u ", (model->mu[ i ] != NULL));
00513     }
00514     fprintf(fp, "\n");
00515 
00516     for (i = 0; i < num_labels; i++)
00517     {
00518         if (model->mu[ i ])
00519         {
00520             for (j = 0; j < model->num_markers; j++)
00521             {
00522                 write_matrix_with_header_fp_d(model->mu[ i ][ j ], fp);
00523                 write_matblock_with_header_fp_d(model->sigma[ i ][ j ], fp);
00524                 write_vector_with_header_fp_d(model->pi[ i ][ j ], fp);
00525             }
00526         }
00527     }
00528 
00529     write_vector_with_header_fp_d(model->priors, fp);
00530 
00531     if (fclose(fp) != 0)
00532     {
00533         snprintf(str, slen, "%s: %s", fname, strerror(errno));
00534         return JWSC_EARG(str);
00535     }
00536 
00537     return NULL;
00538 }
00539 
00540 
00542 void free_nb_gmm_model(NB_gmm_model* model)
00543 {
00544     uint32_t i, j;
00545     uint32_t num_labels;
00546 
00547     if (!model)
00548         return;
00549 
00550     num_labels = get_num_haplo_groups();
00551 
00552     for (i = 0; i < num_labels; i++)
00553     {
00554         if (model->mu[ i ])
00555         {
00556             for (j = 0; j < model->num_markers; j++)
00557             {
00558                 free_matrix_d(model->mu[ i ][ j ]);
00559                 free_matblock_d(model->sigma[ i ][ j ]);
00560                 free_vector_d(model->pi[ i ][ j ]);
00561             }
00562             free(model->mu[ i ]);
00563             free(model->sigma[ i ]);
00564             free(model->pi[ i ]);
00565         }
00566     }
00567     free(model->mu);
00568     free(model->sigma);
00569     free(model->pi);
00570     free_vector_d(model->priors);
00571     free(model);
00572 }
00573 
00574 
00576 static Error* read_nb_gmm_xml_doc
00577 (
00578     xmlDoc**    xml_doc_out,
00579     const char* xml_fname,
00580     const char* dtd_fname
00581 )
00582 {
00583     xmlParserCtxt* xml_parse_ctxt;
00584     xmlValidCtxt*  xml_valid_ctxt;
00585     xmlDtd*        xml_dtd;
00586     int            slen = 256;
00587     char           str[slen];
00588 
00589     assert(xml_parse_ctxt = xmlNewParserCtxt());
00590 
00591     if (!(*xml_doc_out = xmlCtxtReadFile(xml_parse_ctxt, xml_fname, NULL, 0)))
00592     {
00593         snprintf(str, slen, "%s: %s", xml_fname, "Could not parse file");
00594         return JWSC_EARG(str);
00595     } 
00596 
00597     xmlFreeParserCtxt(xml_parse_ctxt);
00598 
00599     if (dtd_fname)
00600     {
00601         assert(xml_valid_ctxt = xmlNewValidCtxt());
00602 
00603         if (!(xml_dtd = xmlParseDTD(NULL, (xmlChar*)dtd_fname)))
00604         {
00605             snprintf(str, slen, "%s: %s", dtd_fname, "Could not parse DTD");
00606             return JWSC_EARG(str);
00607         }
00608 
00609         if (!xmlValidateDtd(xml_valid_ctxt, *xml_doc_out, xml_dtd))
00610         {
00611             snprintf(str, slen, "%s: %s", xml_fname, "XML file not valid");
00612             return JWSC_EARG(str);
00613         }
00614 
00615         xmlFreeValidCtxt(xml_valid_ctxt);
00616         xmlFreeDtd(xml_dtd);
00617     }
00618 
00619     return NULL;
00620 }
00621 
00622 
00629 static Error* create_nb_gmm_model_tree_from_xml_node
00630 (
00631      NB_gmm_model_tree**      tree_out, 
00632      const NB_gmm_model_tree* parent,
00633      uint32_t                 parent_label,
00634      xmlNode*                 xml_node
00635 )
00636 {
00637     NB_gmm_model_tree* tree;
00638     uint32_t i, j;
00639     uint32_t label;
00640     uint32_t altlabel;
00641     uint32_t num_altlabels;
00642     double   p;
00643     xmlAttr* attr;
00644     xmlNode* it;
00645     xmlNode* itt;
00646     Error*   err;
00647     const char* fname;
00648 
00649     if (*tree_out)
00650     {
00651         free_nb_gmm_model_tree(*tree_out);
00652     }
00653 
00654     assert(*tree_out = malloc(sizeof(NB_gmm_model_tree)));
00655     tree = *tree_out;
00656 
00657     tree->parent       = parent;
00658     tree->parent_label = parent_label;
00659     tree->num_groups   = 0;
00660     tree->subtrees     = NULL;
00661     tree->labels       = NULL;
00662     tree->altlabels    = NULL;
00663     tree->priors       = NULL;
00664     tree->model        = NULL;
00665     tree->model_fname  = NULL;
00666 
00667     assert(XMLStrEqual(xml_node->name, "model"));
00668 
00669     for (it = xml_node->children; it; it = it->next)
00670     {
00671         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "group"))
00672         {
00673             tree->num_groups++;
00674         }
00675     }
00676 
00677     assert(tree->subtrees = calloc(tree->num_groups, sizeof(void*)));
00678     assert(tree->altlabels = calloc(tree->num_groups, sizeof(void*)));
00679     create_vector_u32(&(tree->labels), tree->num_groups);
00680 
00681     for (attr = xml_node->properties; attr; attr = attr->next)
00682     {
00683         if (XMLStrEqual(attr->name, "priors") && 
00684                 XMLStrEqual(attr->children->content, "true"))
00685         {
00686             create_init_vector_d(&(tree->priors), get_num_haplo_groups(), 1.0);
00687             break;
00688         }
00689     }
00690 
00691     i = 0;
00692     for (it = xml_node->children; it; it = it->next)
00693     {
00694         j = 0;
00695         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "file"))
00696         {
00697             fname = (const char*) it->children->content;
00698             tree->model_fname = malloc((strlen(fname)+1) * sizeof(char));
00699             strcpy(tree->model_fname, fname);
00700         }
00701         else if (it->type == XML_ELEMENT_NODE && 
00702                 XMLStrEqual(it->name, "mixture-components"))
00703         {
00704             if (sscanf((char*)it->children->content, "%d", 
00705                         &(tree->num_components)) != 1 || 
00706                     tree->num_components < 1)
00707             {
00708                 return JWSC_EARG("Invalid number of mixture components");
00709             }
00710         }
00711         else if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "group"))
00712         {
00713             num_altlabels = 0;
00714             for (itt = it->children; itt; itt = itt->next)
00715             {
00716                 if (itt->type == XML_ELEMENT_NODE &&
00717                         XMLStrEqual(itt->name, "altlabel"))
00718                 {
00719                     num_altlabels++;
00720                 }
00721             }
00722             if (num_altlabels)
00723             {
00724                 create_vector_u32(&(tree->altlabels[ i ]), num_altlabels);
00725             }
00726 
00727             for (itt = it->children; itt; itt = itt->next)
00728             {
00729                 if (itt->type == XML_ELEMENT_NODE &&
00730                         XMLStrEqual(itt->name, "label"))
00731                 {
00732                     if ((err = lookup_haplo_group_index_from_label(&label, 
00733                                     (const char*) itt->children->content)))
00734                     {
00735                         return err;
00736                     }
00737                     tree->labels->elts[ i ] = label;
00738                 }
00739                 else if (itt->type == XML_ELEMENT_NODE &&
00740                         XMLStrEqual(itt->name, "altlabel"))
00741                 {
00742                     if ((err = lookup_haplo_group_index_from_label(&altlabel, 
00743                                     (const char*) itt->children->content)))
00744                     {
00745                         return err;
00746                     }
00747                     tree->altlabels[ i ]->elts[ j++ ] = altlabel;
00748                 }
00749                 else if (tree->priors && itt->type == XML_ELEMENT_NODE &&
00750                         XMLStrEqual(itt->name, "prior"))
00751                 {
00752                     if (sscanf((char*)itt->children->content, "%lf", &p) != 1 ||
00753                             p < 0 || p > 1)
00754                     {
00755                         return JWSC_EARG("Invalid prior");
00756                     }
00757                     tree->priors->elts[ i ] = p;
00758                 }
00759                 else if (itt->type == XML_ELEMENT_NODE &&
00760                         XMLStrEqual(itt->name, "model"))
00761                 {
00762                     if ((err = create_nb_gmm_model_tree_from_xml_node(
00763                                  &(tree->subtrees[ i ]), tree, label, itt)))
00764                     {
00765                         return err;
00766                     }
00767                 }
00768             }
00769             i++;
00770         }
00771     }
00772     assert(i == tree->num_groups);
00773 
00774     return NULL;
00775 }
00776 
00777 
00782 static Error* create_model_training_data
00783 (
00784     Vector_u32**      train_labels_out,
00785     Matrix_i32**      train_markers_out,
00786     const Vector_u32* data_labels,
00787     const Matrix_i32* data_markers,
00788     const Vector_u32* model_labels,
00789     Vector_u32*const* model_altlabels
00790 )
00791 {
00792     uint8_t     b;
00793     uint32_t    i, j, k;
00794     uint32_t    n;
00795     Vector_u32* train_labels = NULL;
00796     Matrix_i32* train_markers = NULL;
00797 
00798     copy_vector_u32(&train_labels, data_labels);
00799     copy_matrix_i32(&train_markers, data_markers);
00800 
00801     n = 0;
00802     for (i = 0; i < data_labels->num_elts; i++)
00803     {
00804         for (j = 0; j < model_labels->num_elts; j++)
00805         {
00806             if (is_ancestor(data_labels->elts[ i ], model_labels->elts[ j ]))
00807             {
00808                 train_labels->elts[ n ] = model_labels->elts[ j ];
00809                 copy_matrix_block_into_matrix_i32(train_markers, n, 0,
00810                         data_markers, i, 0, 1, data_markers->num_cols);
00811                 n++;
00812                 break;
00813             }
00814             else if (model_altlabels[ j ])
00815             {
00816                 for (k = 0; k < model_altlabels[ j ]->num_elts; k++)
00817                 {
00818                     if ((b = is_ancestor(data_labels->elts[ i ], 
00819                                 model_altlabels[ j ]->elts[ k ])))
00820                     {
00821                         train_labels->elts[ n ] = model_labels->elts[ j ];
00822                         copy_matrix_block_into_matrix_i32(train_markers, n, 0,
00823                                 data_markers, i, 0, 1, data_markers->num_cols);
00824                         n++;
00825                         break;
00826                     }
00827                 }
00828                 if (b)
00829                 {
00830                     break;
00831                 }
00832             }
00833         }
00834     }
00835 
00836     if (!n)
00837     {
00838         return JWSC_EARG("No data for model");
00839     }
00840 
00841     copy_vector_section_u32(train_labels_out, train_labels, 0, n);
00842     copy_matrix_block_i32(train_markers_out, train_markers, 0, 0, n, 
00843             train_markers->num_cols);
00844 
00845     free_vector_u32(train_labels);
00846     free_matrix_i32(train_markers);
00847 
00848     return NULL;
00849 }
00850 
00851 
00853 static Error* train_nb_gmm_model_node
00854 (
00855     NB_gmm_model_node* node,
00856     const Vector_u32*  labels, 
00857     const Matrix_i32*  markers
00858 )
00859 {
00860     uint32_t i;
00861     Vector_u32* node_labels  = NULL;
00862     Matrix_i32* node_markers = NULL;
00863     Error*  err;
00864     int     slen = 256;
00865     char    str[slen];
00866 
00867     if ((err = create_model_training_data(&node_labels, &node_markers, labels,
00868                     markers, node->labels, node->altlabels)))
00869     {
00870         snprintf(str, slen, "%s: %s", node->model_fname, err->msg);
00871         return JWSC_EARG(str);
00872     }
00873 
00874     train_nb_gmm_model(&(node->model), node_labels, node_markers, 
00875             node->priors, node->num_components);
00876 
00877     for (i = 0; i < node->num_groups; i++)
00878     {
00879         if (node->subtrees[ i ])
00880         {
00881             if ((err = train_nb_gmm_model_node(node->subtrees[ i ], labels,
00882                         markers)))
00883             {
00884                 return err;
00885             }
00886         }
00887     }
00888 
00889     free_vector_u32(node_labels);
00890     free_matrix_i32(node_markers);
00891 
00892     return NULL;
00893 }
00894 
00895 
00903 Error* train_nb_gmm_model_tree
00904 (
00905     NB_gmm_model_tree** tree_out,
00906     const Vector_u32*   labels, 
00907     const Matrix_i32*   markers,
00908     const char*         tree_xml_fname,
00909     const char*         tree_dtd_fname
00910 )
00911 {
00912     xmlDoc*  xml_doc;
00913     xmlNode* xml_node;
00914     xmlNode* it;
00915     Error*   err;
00916     int      slen = 256;
00917     char     str[slen];
00918 
00919     assert(tree_xml_fname);
00920 
00921     if ((err = read_nb_gmm_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname)))
00922     {
00923         return err;
00924     }
00925 
00926     for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next)
00927     {
00928         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model"))
00929         {
00930             xml_node = it;
00931             break;
00932         }
00933     }
00934 
00935     if ((err = create_nb_gmm_model_tree_from_xml_node(tree_out, NULL, 0, 
00936                     xml_node)))
00937     {
00938         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
00939         return JWSC_EARG(str);
00940     }
00941 
00942     if ((err = train_nb_gmm_model_node(*tree_out, labels, markers)))
00943     {
00944         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
00945         return JWSC_EARG(str);
00946     }
00947 
00948     return NULL;
00949 }
00950 
00951 
00953 static Error* recursively_predict_label_in_model_tree
00954 (
00955     uint32_t*                label_out,
00956     double*                  confidence_out,
00957     const NB_gmm_model_tree* tree,
00958     const Vector_i32*        markers_v,
00959     uint32_t                 order
00960 )
00961 {
00962     uint32_t subtree_label;
00963     uint32_t i;
00964     double   confidence;
00965     Error*   err;
00966 
00967     if ((err = predict_label_with_nb_gmm_model(label_out, &confidence, 
00968                     markers_v, tree->model, order)) != NULL)
00969     {
00970         return err;
00971     }
00972     *confidence_out *= confidence;
00973 
00974     for (i = 0; i < tree->num_groups; i++)
00975     {
00976         if (tree->subtrees[ i ])
00977         {
00978             subtree_label = tree->subtrees[ i ]->parent_label;
00979 
00980             if (subtree_label == *label_out)
00981             {
00982                 if ((err = recursively_predict_label_in_model_tree(label_out,
00983                                 confidence_out, tree->subtrees[ i ], markers_v,
00984                                 order)) != NULL)
00985                 {
00986                     return err;
00987                 }
00988             }
00989         }
00990     }
00991 
00992     return NULL;
00993 }
00994 
00995 
01011 Error* predict_labels_with_nb_gmm_model_tree
01012 (
01013     Vector_u32**             labels_out,
01014     Vector_d**               confidence_out,
01015     const Matrix_i32*        markers,
01016     const NB_gmm_model_tree* tree,
01017     uint32_t                 order
01018 )
01019 {
01020     uint32_t s, num_samples;
01021     uint32_t m, num_markers;
01022 
01023     Vector_u32* labels;
01024     Vector_d*   confidence;
01025     Vector_i32* markers_v;
01026     Error*      err;
01027 
01028     num_samples = markers->num_rows;
01029     num_markers = markers->num_cols;
01030 
01031     create_vector_u32(labels_out, num_samples);
01032     labels = *labels_out;
01033 
01034     create_init_vector_d(confidence_out, num_samples, 1.0);
01035     confidence = *confidence_out;
01036 
01037     markers_v = NULL;
01038     create_vector_i32(&markers_v, num_markers);
01039 
01040     for (s = 0; s < num_samples; s++)
01041     {
01042         for (m = 0; m < num_markers; m++)
01043         {
01044             markers_v->elts[ m ]= markers->elts[ s ][ m ];
01045         }
01046 
01047         if ((err = recursively_predict_label_in_model_tree(
01048                         &(labels->elts[ s ]), 
01049                         &(confidence->elts[ s ]), tree, markers_v, order)) 
01050                 != NULL)
01051         {
01052             return err;
01053         }
01054     }
01055 
01056     free_vector_i32(markers_v);
01057 
01058     return NULL;
01059 }
01060 
01061 
01063 static Error* read_nb_gmm_model_node
01064 (
01065      NB_gmm_model_node* node, 
01066      const char*        model_dirname
01067 )
01068 {
01069     uint32_t    i;
01070     char        buf[1024] = {0};
01071     Error*      err;
01072 
01073     snprintf(buf, 1024, "%s/%s", model_dirname, node->model_fname);
01074     if ((err = read_nb_gmm_model(&(node->model), buf)))
01075     {
01076         return err;
01077     }
01078 
01079     for (i = 0; i < node->num_groups; i++)
01080     {
01081         if (node->subtrees[ i ])
01082         {
01083             if ((err = read_nb_gmm_model_node(node->subtrees[ i ], 
01084                             model_dirname)))
01085             {
01086                 return err;
01087             }
01088         }
01089     }
01090 
01091     return NULL;
01092 }
01093 
01094 
01101 Error* read_nb_gmm_model_tree
01102 (
01103     NB_gmm_model_tree** tree_out,
01104     const char*         tree_xml_fname,
01105     const char*         tree_dtd_fname,
01106     const char*         model_dirname
01107 )
01108 {
01109     xmlDoc*  xml_doc;
01110     xmlNode* xml_node;
01111     xmlNode* it;
01112     Error*   err;
01113     int      slen = 256;
01114     char     str[slen];
01115 
01116     assert(tree_xml_fname && model_dirname);
01117 
01118     if ((err = read_nb_gmm_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname)))
01119     {
01120         return err;
01121     }
01122 
01123     for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next)
01124     {
01125         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model"))
01126         {
01127             xml_node = it;
01128             break;
01129         }
01130     }
01131 
01132     if ((err = create_nb_gmm_model_tree_from_xml_node(tree_out, NULL, 0,
01133                     xml_node)))
01134     {
01135         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
01136         return JWSC_EARG(str);
01137     }
01138 
01139     if ((err = read_nb_gmm_model_node(*tree_out, model_dirname)))
01140     {
01141         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
01142         return JWSC_EARG(str);
01143     }
01144 
01145     return NULL;
01146 }
01147 
01148 
01153 Error* write_nb_gmm_model_tree
01154 (
01155     const NB_gmm_model_tree* tree,
01156     const char*              model_dirname
01157 )
01158 {
01159     uint32_t i;
01160     char     buf[1024] = {0};
01161     Error*   err;
01162     int      slen = 256;
01163     char     str[slen];
01164 
01165     snprintf(buf, 1024, "%s/%s", model_dirname, tree->model_fname);
01166     if ((err = write_nb_gmm_model(tree->model, buf)))
01167     {
01168         return err;
01169     }
01170 
01171     for (i = 0; i < tree->num_groups; i++)
01172     {
01173         if (tree->subtrees[ i ])
01174         {
01175             if ((err = write_nb_gmm_model_tree(tree->subtrees[ i ], 
01176                             model_dirname)))
01177             {
01178                 snprintf(str, slen, "%s: %s", tree->model_fname, err->msg);
01179                 return JWSC_EARG(str);
01180             }
01181         }
01182     }
01183 
01184     return NULL;
01185 }
01186 
01187 
01189 void free_nb_gmm_model_tree(NB_gmm_model_tree* tree)
01190 {
01191     uint32_t i;
01192 
01193     if (!tree)
01194         return;
01195 
01196     for (i = 0; i < tree->num_groups; i++)
01197     {
01198         free_nb_gmm_model_tree(tree->subtrees[ i ]);
01199         free_vector_u32(tree->altlabels[ i ]);
01200     }
01201 
01202     free(tree->subtrees);
01203     free(tree->altlabels);
01204     free_vector_u32(tree->labels);
01205     free_vector_d(tree->priors);
01206     free_nb_gmm_model(tree->model);
01207     free(tree->model_fname);
01208     free(tree);
01209 }