Haplo Prediction
predict haplogroups
mv_gmm.c
Go to the documentation of this file.
00001 /*
00002  * This work is licensed under a Creative Commons 
00003  * Attribution-Noncommercial-Share Alike 3.0 United States License.
00004  * 
00005  *    http://creativecommons.org/licenses/by-nc-sa/3.0/us/
00006  * 
00007  * You are free:
00008  * 
00009  *    to Share - to copy, distribute, display, and perform the work
00010  *    to Remix - to make derivative works
00011  * 
00012  * Under the following conditions:
00013  * 
00014  *    Attribution. You must attribute the work in the manner specified by the
00015  *    author or licensor (but not in any way that suggests that they endorse you
00016  *    or your use of the work).
00017  * 
00018  *    Noncommercial. You may not use this work for commercial purposes.
00019  * 
00020  *    Share Alike. If you alter, transform, or build upon this work, you may
00021  *    distribute the resulting work only under the same or similar license to
00022  *    this one.
00023  * 
00024  * For any reuse or distribution, you must make clear to others the license
00025  * terms of this work. The best way to do this is by including this header.
00026  * 
00027  * Any of the above conditions can be waived if you get permission from the
00028  * copyright holder.
00029  * 
00030  * Apart from the remix rights granted under this license, nothing in this
00031  * license impairs or restricts the author's moral rights.
00032  */
00033 
00034 
00046 #include <config.h>
00047 
00048 #include <stdlib.h>
00049 #include <stdio.h>
00050 #include <inttypes.h>
00051 #include <assert.h>
00052 #include <string.h>
00053 #include <errno.h>
00054 #include <math.h>
00055 
00056 #include <libxml/tree.h>
00057 #include <libxml/parser.h>
00058 #include <libxml/valid.h>
00059 
00060 #ifdef HAPLO_HAVE_DMALLOC
00061 #include <dmalloc.h>
00062 #endif
00063 
00064 #include <jwsc/base/error.h>
00065 #include <jwsc/base/limits.h>
00066 #include <jwsc/base/file_io.h>
00067 #include <jwsc/vector/vector.h>
00068 #include <jwsc/vector/vector_io.h>
00069 #include <jwsc/vector/vector_math.h>
00070 #include <jwsc/matrix/matrix.h>
00071 #include <jwsc/matrix/matrix_io.h>
00072 #include <jwsc/matblock/matblock.h>
00073 #include <jwsc/matblock/matblock_io.h>
00074 #include <jwsc/stat/gmm.h>
00075 
00076 #include "xml.h"
00077 #include "haplo_groups.h"
00078 #include "input.h"
00079 #include "mv_gmm.h"
00080 
00081 
00092 void train_mv_gmm_model
00093 (
00094     MV_gmm_model**    model_out,
00095     const Vector_u32* labels, 
00096     const Matrix_i32* markers,
00097     const Vector_d*   label_priors,
00098     uint32_t          num_components
00099 )
00100 {
00101     uint32_t  l, num_labels;
00102     uint32_t  s, num_samples;
00103     uint32_t  m, num_markers;
00104     uint32_t  n, N;
00105 
00106     Matrix_d**   mu;
00107     Matblock_d** sigma;
00108     Vector_d**   pi;
00109     Vector_d*    priors;
00110     Matrix_d*    markers_d = NULL;
00111 
00112     assert(markers->num_rows == labels->num_elts);
00113 
00114     num_markers = markers->num_cols;
00115     num_samples = markers->num_rows;
00116     num_labels  = get_num_haplo_groups();
00117 
00118     if (*model_out != NULL)
00119     {
00120         free_mv_gmm_model(*model_out);
00121     }
00122 
00123     assert(mu    = malloc(num_labels*sizeof(Matrix_d*)));
00124     assert(sigma = malloc(num_labels*sizeof(Matblock_d*)));
00125     assert(pi    = malloc(num_labels*sizeof(Vector_d*)));
00126     priors = NULL;
00127 
00128     for (l = 0; l < num_labels; l++)
00129     {
00130         mu   [ l ] = NULL;
00131         sigma[ l ] = NULL;
00132         pi   [ l ] = NULL;
00133 
00134         N = 0;
00135         for (s = 0; s < num_samples; s++)
00136         {
00137             if (labels->elts[ s ] == l)
00138             {
00139                 N++;
00140             }
00141         }
00142 
00143         // TODO we should probably handle the case of N=1 by creating a
00144         // distribution with the sample as the mean and a very small diagonal
00145         // covariance matrix.
00146         if (N < num_components*2)
00147             continue;
00148 
00149         create_matrix_d(&markers_d, N, num_markers);
00150 
00151         n = 0;
00152         for (s = 0; s < num_samples; s++)
00153         {
00154             if (labels->elts[ s ] == l)
00155             {
00156                 for (m = 0; m < num_markers; m++)
00157                 {
00158                     markers_d->elts[ n ][ m ] = markers->elts[ s ][ m ];
00159                 }
00160                 n++;
00161             }
00162         }
00163 
00164         train_gmm_d(&(mu[l]), &(sigma[l]), &(pi[l]), NULL, markers_d,
00165                 num_components);
00166     }
00167 
00168     free_matrix_d(markers_d);
00169 
00170     if (label_priors == NULL)
00171     {
00172         create_zero_vector_d(&priors, num_labels);
00173 
00174         for (s = 0; s < num_samples; s++)
00175         {
00176             l = labels->elts[ s ];
00177             priors->elts[ l ]++;
00178         }
00179         normalize_vector_sum_d(&priors, priors);
00180     }
00181     else
00182     {
00183         normalize_vector_sum_d(&priors, label_priors);
00184     }
00185 
00186     assert(*model_out = malloc(sizeof(MV_gmm_model)));
00187     (*model_out)->num_markers    = num_markers;
00188     (*model_out)->num_components = num_components;
00189     (*model_out)->mu             = mu;
00190     (*model_out)->sigma          = sigma;
00191     (*model_out)->pi             = pi;
00192     (*model_out)->priors         = priors;
00193 }
00194 
00195 
00207 Error* predict_label_with_mv_gmm_model
00208 (
00209     uint32_t*           label_out,
00210     double*             confidence_out,
00211     const Vector_i32*   markers,
00212     const MV_gmm_model* model,
00213     uint32_t            order
00214 )
00215 {
00216     uint32_t num_markers;
00217     uint32_t marker;
00218     uint32_t num_labels;
00219     uint32_t label;
00220     uint32_t best_label;
00221     uint32_t i;
00222     double   log_map;
00223     double   log_ll;
00224     double   log_prior;
00225     double   log_posterior;
00226     double   posterior_sum;
00227 
00228     Vector_u32* best_labels     = NULL;
00229     Vector_d*  best_labels_conf = NULL;
00230     Matrix_d*  marker_vals      = NULL;
00231 
00232     const Matrix_d*   mu;
00233     const Matblock_d* sigma;
00234     const Vector_d*   pi;
00235 
00236     num_labels = get_num_haplo_groups();
00237     num_markers = markers->num_elts;
00238 
00239     log_map = log(0);
00240     posterior_sum = 0;
00241 
00242     create_init_vector_u32(&best_labels, num_labels, 0);
00243     create_init_vector_d(&best_labels_conf, num_labels, 0);
00244 
00245     create_matrix_d(&marker_vals, 1, num_markers);
00246 
00247     for (label = 0; label < num_labels; label++)
00248     {
00249         if (!model->mu[ label ])
00250             continue;
00251 
00252         for (marker = 0; marker < num_markers; marker++)
00253         {
00254             if (markers->elts[ marker ] <= 0)
00255             {
00256                 // Impute missing value with the mean.
00257                 marker_vals->elts[0][ marker ] = 0;
00258                 for (i = 0; i < model->num_components; i++)
00259                 {
00260                     marker_vals->elts[0][ marker ] += 
00261                         model->mu[ label ]->elts[ i ][ marker ] * 
00262                         model->pi[ label ]->elts[ i ];
00263                 }
00264             }
00265             else
00266             {
00267                 marker_vals->elts[0][ marker ] = markers->elts[ marker ];
00268             }
00269         }
00270 
00271         mu = model->mu[ label ];
00272         sigma = model->sigma[ label ];
00273         pi = model->pi[ label ];
00274 
00275         log_ll = gmm_log_likelihood_d(mu, sigma, pi, marker_vals);
00276 
00277         if (model->priors->elts[ label ] < JWSC_MIN_LOG_ARG)
00278         {
00279             log_prior = log(JWSC_MIN_LOG_ARG);
00280         }
00281         else
00282         {
00283             log_prior = log(model->priors->elts[ label ]);
00284         }
00285 
00286         log_posterior = log_ll + log_prior;
00287         posterior_sum += exp(log_posterior);
00288 
00289         if (log_posterior > log_map)
00290         {
00291             log_map = log_posterior;
00292             best_label = label;
00293             for (i = num_labels-1; i > 0; i--)
00294             {
00295                 best_labels->elts[ i ] = best_labels->elts[ i - 1];
00296                 best_labels_conf->elts[ i ] = best_labels_conf->elts[ i - 1];
00297             }
00298             best_labels->elts[ 0 ] = best_label;
00299             best_labels_conf->elts[ 0 ] = log_map;
00300         }
00301     }
00302 
00303     assert(order < best_labels->num_elts);
00304 
00305     *label_out = best_labels->elts[ order ];
00306 
00307     if (posterior_sum > 0)
00308     {
00309         assert(finite(*confidence_out = exp(best_labels_conf->elts[order]) / 
00310             (double)posterior_sum));
00311     }
00312     else
00313     {
00314         *confidence_out = 0;
00315     }
00316 
00317     free_vector_u32(best_labels);
00318     free_vector_d(best_labels_conf);
00319     free_matrix_d(marker_vals);
00320 
00321     return NULL;
00322 }
00323 
00324 
00338 Error* predict_labels_with_mv_gmm_model
00339 (
00340     Vector_u32**        labels_out,
00341     Vector_d**          confidence_out,
00342     const Matrix_i32*   markers,
00343     const MV_gmm_model* model
00344 )
00345 {
00346     uint32_t num_samples;
00347     uint32_t sample;
00348     uint32_t num_markers;
00349     uint32_t marker;
00350 
00351     Vector_u32* labels;
00352     Vector_d*   confidence;
00353     Vector_i32* markers_v;
00354     Error*      e;
00355 
00356     num_samples = markers->num_rows;
00357     num_markers = markers->num_cols;
00358 
00359     create_vector_u32(labels_out, num_samples);
00360     labels = *labels_out;
00361 
00362     create_vector_d(confidence_out, num_samples);
00363     confidence = *confidence_out;
00364 
00365     markers_v = NULL;
00366     create_vector_i32(&markers_v, num_markers);
00367 
00368     for (sample = 0; sample < num_samples; sample++)
00369     {
00370         for (marker = 0; marker < num_markers; marker++)
00371         {
00372             markers_v->elts[ marker ] = markers->elts[ sample ][marker];
00373         }
00374 
00375         if ((e = predict_label_with_mv_gmm_model(&(labels->elts[ sample ]), 
00376                 &(confidence->elts[ sample ]), markers_v, model, 0)) != NULL)
00377         {
00378             return e;
00379         }
00380     }
00381 
00382     free_vector_i32(markers_v);
00383 
00384     return NULL;
00385 }
00386 
00387 
00396 Error* read_mv_gmm_model(MV_gmm_model** model_out, const char* fname)
00397 {
00398     uint32_t      num_labels;
00399     uint32_t      i;
00400     FILE*         fp;
00401     Vector_u32*   v = NULL;
00402     MV_gmm_model* model;
00403     Error*        e;
00404     int           slen = 256;
00405     char          str[slen];
00406 
00407     num_labels = get_num_haplo_groups();
00408 
00409     if ((fp = fopen(fname, "r")) == NULL)
00410     {
00411         snprintf(str, slen, "%s: %s", fname, strerror(errno));
00412         return JWSC_EARG(str);
00413     }
00414 
00415     if (*model_out == NULL)
00416     {
00417         assert((*model_out = malloc(sizeof(MV_gmm_model))) != NULL);
00418         model = *model_out;
00419         model->priors = NULL;
00420     }
00421 
00422     if (fscanf(fp, "%u\n%u\n", &(model->num_markers), 
00423                 &(model->num_components)) != 2)
00424     {
00425         snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file");
00426         return JWSC_EARG(str);
00427     }
00428 
00429     assert(model->mu    = malloc(num_labels*sizeof(void*)));
00430     assert(model->sigma = malloc(num_labels*sizeof(void*)));
00431     assert(model->pi    = malloc(num_labels*sizeof(void*)));
00432 
00433     if ((e = read_vector_fp_u32(&v, fp)) || (v->num_elts != num_labels))
00434     {
00435         return JWSC_EARG("Improperly formatted model file");
00436     }
00437 
00438     for (i = 0; i < num_labels; i++)
00439     {
00440         model->mu[ i ]    = NULL;
00441         model->sigma[ i ] = NULL;
00442         model->pi[ i ]    = NULL;
00443 
00444         if (v->elts[ i ])
00445         {
00446             if ((e = read_matrix_with_header_fp_d(&(model->mu[i]), fp)) ||
00447                 (e = read_matblock_with_header_fp_d(&(model->sigma[i]), fp)) ||
00448                 (e = read_vector_with_header_fp_d(&(model->pi[i]), fp)))
00449             {
00450                 return JWSC_EARG("Improperly formatted model file");
00451             }
00452         }
00453     }
00454 
00455     if ((e = read_vector_with_header_fp_d(&(model->priors), fp)) != NULL)
00456     {
00457         snprintf(str, slen, "%s: %s", fname, "Improperly formatted model file");
00458         return JWSC_EARG(str);
00459     }
00460 
00461     if (fclose(fp) != 0)
00462     {
00463         snprintf(str, slen, "%s: %s", fname, strerror(errno));
00464         return JWSC_EARG(str);
00465     }
00466 
00467     return NULL;
00468 }
00469 
00470 
00477 Error* write_mv_gmm_model(MV_gmm_model* model, const char* fname)
00478 {
00479     uint32_t num_labels;
00480     uint32_t i;
00481     FILE*    fp;
00482     int      slen = 256;
00483     char     str[slen];
00484 
00485     num_labels = get_num_haplo_groups();
00486 
00487     if ((fp = fopen(fname, "w")) == NULL)
00488     {
00489         snprintf(str, slen, "%s: %s", fname, strerror(errno));
00490         return JWSC_EARG(str);
00491     }
00492 
00493     fprintf(fp, "%u\n", model->num_markers);
00494     fprintf(fp, "%u\n", model->num_components);
00495 
00496     for (i = 0; i < num_labels; i++)
00497     {
00498         fprintf(fp, "%u ", (model->mu[ i ] != NULL));
00499     }
00500     fprintf(fp, "\n");
00501 
00502     for (i = 0; i < num_labels; i++)
00503     {
00504         if (model->mu[ i ])
00505         {
00506             write_matrix_with_header_fp_d(model->mu[ i ], fp);
00507             write_matblock_with_header_fp_d(model->sigma[ i ], fp);
00508             write_vector_with_header_fp_d(model->pi[ i ], fp);
00509         }
00510     }
00511 
00512     write_vector_with_header_fp_d(model->priors, fp);
00513 
00514     if (fclose(fp) != 0)
00515     {
00516         snprintf(str, slen, "%s: %s", fname, strerror(errno));
00517         return JWSC_EARG(str);
00518     }
00519 
00520     return NULL;
00521 }
00522 
00523 
00525 void free_mv_gmm_model(MV_gmm_model* model)
00526 {
00527     uint32_t i;
00528     uint32_t num_labels;
00529 
00530     if (!model)
00531         return;
00532 
00533     num_labels = get_num_haplo_groups();
00534 
00535     for (i = 0; i < num_labels; i++)
00536     {
00537         free_matrix_d(model->mu[ i ]);
00538         free_matblock_d(model->sigma[ i ]);
00539         free_vector_d(model->pi[ i ]);
00540     }
00541     free(model->mu);
00542     free(model->sigma);
00543     free(model->pi);
00544     free_vector_d(model->priors);
00545     free(model);
00546 }
00547 
00548 
00550 static Error* read_mv_gmm_xml_doc
00551 (
00552     xmlDoc**    xml_doc_out,
00553     const char* xml_fname,
00554     const char* dtd_fname
00555 )
00556 {
00557     xmlParserCtxt* xml_parse_ctxt;
00558     xmlValidCtxt*  xml_valid_ctxt;
00559     xmlDtd*        xml_dtd;
00560     int            slen = 256;
00561     char           str[slen];
00562 
00563     assert(xml_parse_ctxt = xmlNewParserCtxt());
00564 
00565     if (!(*xml_doc_out = xmlCtxtReadFile(xml_parse_ctxt, xml_fname, NULL, 0)))
00566     {
00567         snprintf(str, slen, "%s: %s", xml_fname, "Could not parse file");
00568         return JWSC_EARG(str);
00569     } 
00570 
00571     xmlFreeParserCtxt(xml_parse_ctxt);
00572 
00573     if (dtd_fname)
00574     {
00575         assert(xml_valid_ctxt = xmlNewValidCtxt());
00576 
00577         if (!(xml_dtd = xmlParseDTD(NULL, (xmlChar*)dtd_fname)))
00578         {
00579             snprintf(str, slen, "%s: %s", dtd_fname, "Could not parse DTD");
00580             return JWSC_EARG(str);
00581         }
00582 
00583         if (!xmlValidateDtd(xml_valid_ctxt, *xml_doc_out, xml_dtd))
00584         {
00585             snprintf(str, slen, "%s: %s", xml_fname, "XML file not valid");
00586             return JWSC_EARG(str);
00587         }
00588 
00589         xmlFreeValidCtxt(xml_valid_ctxt);
00590         xmlFreeDtd(xml_dtd);
00591     }
00592 
00593     return NULL;
00594 }
00595 
00596 
00603 static Error* create_mv_gmm_model_tree_from_xml_node
00604 (
00605      MV_gmm_model_tree**      tree_out, 
00606      const MV_gmm_model_tree* parent,
00607      uint32_t                 parent_label,
00608      xmlNode*                 xml_node
00609 )
00610 {
00611     MV_gmm_model_tree* tree;
00612     uint32_t i, j;
00613     uint32_t label;
00614     uint32_t altlabel;
00615     uint32_t num_altlabels;
00616     double   p;
00617     xmlAttr* attr;
00618     xmlNode* it;
00619     xmlNode* itt;
00620     Error*   err;
00621     const char* fname;
00622 
00623     if (*tree_out)
00624     {
00625         free_mv_gmm_model_tree(*tree_out);
00626     }
00627 
00628     assert(*tree_out = malloc(sizeof(MV_gmm_model_tree)));
00629     tree = *tree_out;
00630 
00631     tree->parent       = parent;
00632     tree->parent_label = parent_label;
00633     tree->num_groups   = 0;
00634     tree->subtrees     = NULL;
00635     tree->labels       = NULL;
00636     tree->altlabels    = NULL;
00637     tree->priors       = NULL;
00638     tree->model        = NULL;
00639     tree->model_fname  = NULL;
00640 
00641     assert(XMLStrEqual(xml_node->name, "model"));
00642 
00643     for (it = xml_node->children; it; it = it->next)
00644     {
00645         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "group"))
00646         {
00647             tree->num_groups++;
00648         }
00649     }
00650 
00651     assert(tree->subtrees = calloc(tree->num_groups, sizeof(void*)));
00652     assert(tree->altlabels = calloc(tree->num_groups, sizeof(void*)));
00653     create_vector_u32(&(tree->labels), tree->num_groups);
00654 
00655     for (attr = xml_node->properties; attr; attr = attr->next)
00656     {
00657         if (XMLStrEqual(attr->name, "priors") && 
00658                 XMLStrEqual(attr->children->content, "true"))
00659         {
00660             create_init_vector_d(&(tree->priors), get_num_haplo_groups(), 1.0);
00661             break;
00662         }
00663     }
00664 
00665     i = 0;
00666     for (it = xml_node->children; it; it = it->next)
00667     {
00668         j = 0;
00669         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "file"))
00670         {
00671             fname = (const char*) it->children->content;
00672             tree->model_fname = malloc((strlen(fname)+1) * sizeof(char));
00673             strcpy(tree->model_fname, fname);
00674         }
00675         else if (it->type == XML_ELEMENT_NODE && 
00676                 XMLStrEqual(it->name, "mixture-components"))
00677         {
00678             if (sscanf((char*)it->children->content, "%d", 
00679                         &(tree->num_components)) != 1 || 
00680                     tree->num_components < 1)
00681             {
00682                 return JWSC_EARG("Invalid number of mixture components");
00683             }
00684         }
00685         else if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "group"))
00686         {
00687             num_altlabels = 0;
00688             for (itt = it->children; itt; itt = itt->next)
00689             {
00690                 if (itt->type == XML_ELEMENT_NODE &&
00691                         XMLStrEqual(itt->name, "altlabel"))
00692                 {
00693                     num_altlabels++;
00694                 }
00695             }
00696             if (num_altlabels)
00697             {
00698                 create_vector_u32(&(tree->altlabels[ i ]), num_altlabels);
00699             }
00700 
00701             for (itt = it->children; itt; itt = itt->next)
00702             {
00703                 if (itt->type == XML_ELEMENT_NODE &&
00704                         XMLStrEqual(itt->name, "label"))
00705                 {
00706                     if ((err = lookup_haplo_group_index_from_label(&label, 
00707                                     (const char*) itt->children->content)))
00708                     {
00709                         return err;
00710                     }
00711                     tree->labels->elts[ i ] = label;
00712                 }
00713                 else if (itt->type == XML_ELEMENT_NODE &&
00714                         XMLStrEqual(itt->name, "altlabel"))
00715                 {
00716                     if ((err = lookup_haplo_group_index_from_label(&altlabel, 
00717                                     (const char*) itt->children->content)))
00718                     {
00719                         return err;
00720                     }
00721                     tree->altlabels[ i ]->elts[ j++ ] = altlabel;
00722                 }
00723                 else if (tree->priors && itt->type == XML_ELEMENT_NODE &&
00724                         XMLStrEqual(itt->name, "prior"))
00725                 {
00726                     if (sscanf((char*)itt->children->content, "%lf", &p) != 1 ||
00727                             p < 0 || p > 1)
00728                     {
00729                         return JWSC_EARG("Invalid prior");
00730                     }
00731                     tree->priors->elts[ i ] = p;
00732                 }
00733                 else if (itt->type == XML_ELEMENT_NODE &&
00734                         XMLStrEqual(itt->name, "model"))
00735                 {
00736                     if ((err = create_mv_gmm_model_tree_from_xml_node(
00737                                  &(tree->subtrees[ i ]), tree, label, itt)))
00738                     {
00739                         return err;
00740                     }
00741                 }
00742             }
00743             i++;
00744         }
00745     }
00746     assert(i == tree->num_groups);
00747 
00748     return NULL;
00749 }
00750 
00751 
00756 static Error* create_model_training_data
00757 (
00758     Vector_u32**      train_labels_out,
00759     Matrix_i32**      train_markers_out,
00760     const Vector_u32* data_labels,
00761     const Matrix_i32* data_markers,
00762     const Vector_u32* model_labels,
00763     Vector_u32*const* model_altlabels
00764 )
00765 {
00766     uint8_t     b;
00767     uint32_t    i, j, k;
00768     uint32_t    n;
00769     Vector_u32* train_labels = NULL;
00770     Matrix_i32* train_markers = NULL;
00771 
00772     copy_vector_u32(&train_labels, data_labels);
00773     copy_matrix_i32(&train_markers, data_markers);
00774 
00775     n = 0;
00776     for (i = 0; i < data_labels->num_elts; i++)
00777     {
00778         for (j = 0; j < model_labels->num_elts; j++)
00779         {
00780             if (is_ancestor(data_labels->elts[ i ], model_labels->elts[ j ]))
00781             {
00782                 train_labels->elts[ n ] = model_labels->elts[ j ];
00783                 copy_matrix_block_into_matrix_i32(train_markers, n, 0,
00784                         data_markers, i, 0, 1, data_markers->num_cols);
00785                 n++;
00786                 break;
00787             }
00788             else if (model_altlabels[ j ])
00789             {
00790                 for (k = 0; k < model_altlabels[ j ]->num_elts; k++)
00791                 {
00792                     if ((b = is_ancestor(data_labels->elts[ i ], 
00793                                 model_altlabels[ j ]->elts[ k ])))
00794                     {
00795                         train_labels->elts[ n ] = model_labels->elts[ j ];
00796                         copy_matrix_block_into_matrix_i32(train_markers, n, 0,
00797                                 data_markers, i, 0, 1, data_markers->num_cols);
00798                         n++;
00799                         break;
00800                     }
00801                 }
00802                 if (b)
00803                 {
00804                     break;
00805                 }
00806             }
00807         }
00808     }
00809 
00810     if (!n)
00811     {
00812         return JWSC_EARG("No data for model");
00813     }
00814 
00815     copy_vector_section_u32(train_labels_out, train_labels, 0, n);
00816     copy_matrix_block_i32(train_markers_out, train_markers, 0, 0, n, 
00817             train_markers->num_cols);
00818 
00819     free_vector_u32(train_labels);
00820     free_matrix_i32(train_markers);
00821 
00822     return NULL;
00823 }
00824 
00825 
00827 static Error* train_mv_gmm_model_node
00828 (
00829     MV_gmm_model_node* node,
00830     const Vector_u32*  labels, 
00831     const Matrix_i32*  markers
00832 )
00833 {
00834     uint32_t i;
00835     Vector_u32* node_labels  = NULL;
00836     Matrix_i32* node_markers = NULL;
00837     Error*  err;
00838     int     slen = 256;
00839     char    str[slen];
00840 
00841     if ((err = create_model_training_data(&node_labels, &node_markers, labels,
00842                     markers, node->labels, node->altlabels)))
00843     {
00844         snprintf(str, slen, "%s: %s", node->model_fname, err->msg);
00845         return JWSC_EARG(str);
00846     }
00847 
00848     train_mv_gmm_model(&(node->model), node_labels, node_markers, 
00849             node->priors, node->num_components);
00850 
00851     for (i = 0; i < node->num_groups; i++)
00852     {
00853         if (node->subtrees[ i ])
00854         {
00855             if ((err = train_mv_gmm_model_node(node->subtrees[ i ], labels,
00856                         markers)))
00857             {
00858                 return err;
00859             }
00860         }
00861     }
00862 
00863     free_vector_u32(node_labels);
00864     free_matrix_i32(node_markers);
00865 
00866     return NULL;
00867 }
00868 
00869 
00877 Error* train_mv_gmm_model_tree
00878 (
00879     MV_gmm_model_tree** tree_out,
00880     const Vector_u32*   labels, 
00881     const Matrix_i32*   markers,
00882     const char*         tree_xml_fname,
00883     const char*         tree_dtd_fname
00884 )
00885 {
00886     xmlDoc*  xml_doc;
00887     xmlNode* xml_node;
00888     xmlNode* it;
00889     Error*   err;
00890     int      slen = 256;
00891     char     str[slen];
00892 
00893     assert(tree_xml_fname);
00894 
00895     if ((err = read_mv_gmm_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname)))
00896     {
00897         return err;
00898     }
00899 
00900     for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next)
00901     {
00902         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model"))
00903         {
00904             xml_node = it;
00905             break;
00906         }
00907     }
00908 
00909     if ((err = create_mv_gmm_model_tree_from_xml_node(tree_out, NULL, 0, 
00910                     xml_node)))
00911     {
00912         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
00913         return JWSC_EARG(str);
00914     }
00915 
00916     if ((err = train_mv_gmm_model_node(*tree_out, labels, markers)))
00917     {
00918         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
00919         return JWSC_EARG(str);
00920     }
00921 
00922     return NULL;
00923 }
00924 
00925 
00927 static Error* recursively_predict_label_in_model_tree
00928 (
00929     uint32_t*                label_out,
00930     double*                  confidence_out,
00931     const MV_gmm_model_tree* tree,
00932     const Vector_i32*        markers_v,
00933     uint32_t                 order
00934 )
00935 {
00936     uint32_t subtree_label;
00937     uint32_t i;
00938     double   confidence;
00939     Error*   err;
00940 
00941     if ((err = predict_label_with_mv_gmm_model(label_out, &confidence, 
00942                     markers_v, tree->model, order)) != NULL)
00943     {
00944         return err;
00945     }
00946     *confidence_out *= confidence;
00947 
00948     for (i = 0; i < tree->num_groups; i++)
00949     {
00950         if (tree->subtrees[ i ])
00951         {
00952             subtree_label = tree->subtrees[ i ]->parent_label;
00953 
00954             if (subtree_label == *label_out)
00955             {
00956                 if ((err = recursively_predict_label_in_model_tree(label_out,
00957                                 confidence_out, tree->subtrees[ i ], markers_v,
00958                                 order)) != NULL)
00959                 {
00960                     return err;
00961                 }
00962             }
00963         }
00964     }
00965 
00966     return NULL;
00967 }
00968 
00969 
00985 Error* predict_labels_with_mv_gmm_model_tree
00986 (
00987     Vector_u32**             labels_out,
00988     Vector_d**               confidence_out,
00989     const Matrix_i32*        markers,
00990     const MV_gmm_model_tree* tree,
00991     uint32_t                 order
00992 )
00993 {
00994     uint32_t s, num_samples;
00995     uint32_t m, num_markers;
00996 
00997     Vector_u32* labels;
00998     Vector_d*   confidence;
00999     Vector_i32* markers_v;
01000     Error*      err;
01001 
01002     num_samples = markers->num_rows;
01003     num_markers = markers->num_cols;
01004 
01005     create_vector_u32(labels_out, num_samples);
01006     labels = *labels_out;
01007 
01008     create_init_vector_d(confidence_out, num_samples, 1.0);
01009     confidence = *confidence_out;
01010 
01011     markers_v = NULL;
01012     create_vector_i32(&markers_v, num_markers);
01013 
01014     for (s = 0; s < num_samples; s++)
01015     {
01016         for (m = 0; m < num_markers; m++)
01017         {
01018             markers_v->elts[ m ]= markers->elts[ s ][ m ];
01019         }
01020 
01021         if ((err = recursively_predict_label_in_model_tree(
01022                         &(labels->elts[ s ]), 
01023                         &(confidence->elts[ s ]), tree, markers_v, order)) 
01024                 != NULL)
01025         {
01026             return err;
01027         }
01028     }
01029 
01030     free_vector_i32(markers_v);
01031 
01032     return NULL;
01033 }
01034 
01035 
01037 static Error* read_mv_gmm_model_node
01038 (
01039      MV_gmm_model_node* node, 
01040      const char*        model_dirname
01041 )
01042 {
01043     uint32_t    i;
01044     char        buf[1024] = {0};
01045     Error*      err;
01046 
01047     snprintf(buf, 1024, "%s/%s", model_dirname, node->model_fname);
01048     if ((err = read_mv_gmm_model(&(node->model), buf)))
01049     {
01050         return err;
01051     }
01052 
01053     for (i = 0; i < node->num_groups; i++)
01054     {
01055         if (node->subtrees[ i ])
01056         {
01057             if ((err = read_mv_gmm_model_node(node->subtrees[ i ], 
01058                             model_dirname)))
01059             {
01060                 return err;
01061             }
01062         }
01063     }
01064 
01065     return NULL;
01066 }
01067 
01068 
01075 Error* read_mv_gmm_model_tree
01076 (
01077     MV_gmm_model_tree** tree_out,
01078     const char*         tree_xml_fname,
01079     const char*         tree_dtd_fname,
01080     const char*         model_dirname
01081 )
01082 {
01083     xmlDoc*  xml_doc;
01084     xmlNode* xml_node;
01085     xmlNode* it;
01086     Error*   err;
01087     int      slen = 256;
01088     char     str[slen];
01089 
01090     assert(tree_xml_fname && model_dirname);
01091 
01092     if ((err = read_mv_gmm_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname)))
01093     {
01094         return err;
01095     }
01096 
01097     for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next)
01098     {
01099         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model"))
01100         {
01101             xml_node = it;
01102             break;
01103         }
01104     }
01105 
01106     if ((err = create_mv_gmm_model_tree_from_xml_node(tree_out, NULL, 0,
01107                     xml_node)))
01108     {
01109         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
01110         return JWSC_EARG(str);
01111     }
01112 
01113     if ((err = read_mv_gmm_model_node(*tree_out, model_dirname)))
01114     {
01115         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
01116         return JWSC_EARG(str);
01117     }
01118 
01119     return NULL;
01120 }
01121 
01122 
01127 Error* write_mv_gmm_model_tree
01128 (
01129     const MV_gmm_model_tree* tree,
01130     const char*              model_dirname
01131 )
01132 {
01133     uint32_t i;
01134     char     buf[1024] = {0};
01135     Error*   err;
01136     int      slen = 256;
01137     char     str[slen];
01138 
01139     snprintf(buf, 1024, "%s/%s", model_dirname, tree->model_fname);
01140     if ((err = write_mv_gmm_model(tree->model, buf)))
01141     {
01142         return err;
01143     }
01144 
01145     for (i = 0; i < tree->num_groups; i++)
01146     {
01147         if (tree->subtrees[ i ])
01148         {
01149             if ((err = write_mv_gmm_model_tree(tree->subtrees[ i ], 
01150                             model_dirname)))
01151             {
01152                 snprintf(str, slen, "%s: %s", tree->model_fname, err->msg);
01153                 return JWSC_EARG(str);
01154             }
01155         }
01156     }
01157 
01158     return NULL;
01159 }
01160 
01161 
01163 void free_mv_gmm_model_tree(MV_gmm_model_tree* tree)
01164 {
01165     uint32_t i;
01166 
01167     if (!tree)
01168         return;
01169 
01170     for (i = 0; i < tree->num_groups; i++)
01171     {
01172         free_mv_gmm_model_tree(tree->subtrees[ i ]);
01173         free_vector_u32(tree->altlabels[ i ]);
01174     }
01175 
01176     free(tree->subtrees);
01177     free(tree->altlabels);
01178     free_vector_u32(tree->labels);
01179     free_vector_d(tree->priors);
01180     free_mv_gmm_model(tree->model);
01181     free(tree->model_fname);
01182     free(tree);
01183 }