Haplo Prediction
predict haplogroups
svm_tree.c
Go to the documentation of this file.
00001 /*
00002  * This work is licensed under a Creative Commons 
00003  * Attribution-Noncommercial-Share Alike 3.0 United States License.
00004  * 
00005  *    http://creativecommons.org/licenses/by-nc-sa/3.0/us/
00006  * 
00007  * You are free:
00008  * 
00009  *    to Share - to copy, distribute, display, and perform the work
00010  *    to Remix - to make derivative works
00011  * 
00012  * Under the following conditions:
00013  * 
00014  *    Attribution. You must attribute the work in the manner specified by the
00015  *    author or licensor (but not in any way that suggests that they endorse you
00016  *    or your use of the work).
00017  * 
00018  *    Noncommercial. You may not use this work for commercial purposes.
00019  * 
00020  *    Share Alike. If you alter, transform, or build upon this work, you may
00021  *    distribute the resulting work only under the same or similar license to
00022  *    this one.
00023  * 
00024  * For any reuse or distribution, you must make clear to others the license
00025  * terms of this work. The best way to do this is by including this header.
00026  * 
00027  * Any of the above conditions can be waived if you get permission from the
00028  * copyright holder.
00029  * 
00030  * Apart from the remix rights granted under this license, nothing in this
00031  * license impairs or restricts the author's moral rights.
00032  */
00033 
00034 
00046 #include <config.h>
00047 
00048 #include <stdlib.h>
00049 #include <stdio.h>
00050 #include <inttypes.h>
00051 #include <assert.h>
00052 #include <string.h>
00053 #include <errno.h>
00054 #include <math.h>
00055 
00056 #include <libxml/tree.h>
00057 #include <libxml/parser.h>
00058 #include <libxml/valid.h>
00059 
00060 #if defined HAPLO_HAVE_LIBSVM_H
00061 #include <libsvm.h>
00062 #elif defined HAPLO_HAVE_SVM_H
00063 #include <svm.h>
00064 #endif
00065 
00066 #ifdef HAPLO_HAVE_DMALLOC
00067 #include <dmalloc.h>
00068 #endif
00069 
00070 #include <jwsc/base/error.h>
00071 #include <jwsc/base/limits.h>
00072 #include <jwsc/base/file_io.h>
00073 #include <jwsc/vector/vector.h>
00074 #include <jwsc/vector/vector_io.h>
00075 #include <jwsc/vector/vector_math.h>
00076 #include <jwsc/matrix/matrix.h>
00077 #include <jwsc/matrix/matrix_io.h>
00078 #include <jwsc/matblock/matblock.h>
00079 #include <jwsc/matblock/matblock_io.h>
00080 #include <jwsc/stat/gmm.h>
00081 
00082 #include "xml.h"
00083 #include "haplo_groups.h"
00084 #include "svm_tree.h"
00085 
00086 
00097 void train_svm_model
00098 (
00099     SVM_model**        model_out,
00100     const Vector_u32*  labels,
00101     const Matrix_i32*  markers,
00102     double             cost,
00103     double             gamma
00104 )
00105 {
00106     uint32_t s, num_samples;
00107     uint32_t m, num_markers;
00108     int32_t  marker_val;
00109 
00110     SVM_model*            model;
00111     struct svm_parameter* param;
00112 
00113     if (*model_out != NULL)
00114     {
00115         free_svm_model(*model_out);
00116     }
00117 
00118     assert(*model_out = malloc(sizeof(SVM_model)));
00119     model = *model_out;
00120 
00121     num_samples = markers->num_rows;
00122     num_markers = markers->num_cols;
00123 
00124     assert(model->prob = malloc(sizeof(struct svm_problem)));
00125     assert(model->prob->y = malloc(num_samples*sizeof(double)));
00126     assert(model->prob->x = malloc(num_samples*sizeof(struct svm_node*)));
00127 
00128     model->prob->l = num_samples;
00129     for (s = 0; s < num_samples; s++)
00130     {
00131         model->prob->y[ s ] = labels->elts[ s ];
00132         model->prob->x[ s ] = malloc((num_markers+1)*sizeof(struct svm_node));
00133         assert(model->prob->x[ s ]);
00134 
00135         for (m = 0; m < num_markers; m++)
00136         {
00137             model->prob->x[ s ][ m ].index = m+1;
00138             marker_val = markers->elts[ s ][ m ];
00139             model->prob->x[ s ][ m ].value = marker_val;
00140         }
00141         model->prob->x[ s ][ m ].index = -1;
00142     }
00143 
00144     assert(param = malloc(sizeof(struct svm_parameter)));
00145 
00146     param->svm_type = C_SVC;
00147     param->kernel_type = RBF;
00148     param->gamma = gamma;
00149     param->C = cost;
00150     param->nr_weight = 0;
00151     param->probability = 1;
00152     param->shrinking = 1;       /* svm-train default */
00153     param->eps = 0.001;         /* svm-train default */
00154     param->cache_size = 40;     /* svm-train default */
00155 
00156     assert(model->svm = svm_train(model->prob, param));
00157 
00158     free(param);
00159 }
00160 
00161 
00171 Error* predict_label_with_svm_model
00172 (
00173     uint32_t*               label_out,
00174     double**                confidence_out,
00175     const struct svm_node*  markers,
00176     const SVM_model*        model
00177 )
00178 {
00179     assert(*confidence_out);
00180     *label_out = svm_predict_probability(model->svm, markers, *confidence_out);
00181 
00182     return NULL;
00183 }
00184 
00185 
00187 static void get_predicted_labels_confidence
00188 (
00189     Vector_d**        confidence_out,
00190     const Vector_u32* labels_v,
00191     const Matrix_d*   confidence_matrix, 
00192     const SVM_model*  model
00193 )
00194 {
00195     uint32_t  s, num_samples;
00196     uint32_t  l, num_labels;
00197     uint32_t  label;
00198     int*      labels;
00199     Vector_d* confidence;
00200 
00201     num_samples = labels_v->num_elts;
00202 
00203     create_zero_vector_d(confidence_out, num_samples);
00204     confidence = *confidence_out;
00205 
00206     num_labels = (uint32_t)svm_get_nr_class(model->svm);
00207     assert(labels = malloc(num_labels*sizeof(uint32_t)));
00208     svm_get_labels(model->svm, labels);
00209 
00210     for (s = 0; s < num_samples; s++)
00211     {
00212         label = labels_v->elts[ s ];
00213 
00214         for (l = 0; l < num_labels; l++)
00215         {
00216             if ((uint32_t)labels[ l ] == label)
00217             {
00218                 confidence->elts[ s ] = confidence_matrix->elts[ s ][ l ];
00219             }
00220         }
00221     }
00222 
00223     free(labels);
00224 }
00225 
00226 
00240 Error* predict_labels_with_svm_model
00241 (
00242     Vector_u32**      labels_out,
00243     Vector_d**        confidence_out,
00244     const Matrix_i32* markers,
00245     const SVM_model*  model
00246 )
00247 {
00248     uint32_t s, num_samples;
00249     uint32_t m, num_markers;
00250 
00251     Vector_u32* labels;
00252     Matrix_d*   confidence_matrix;
00253 
00254     struct svm_node* markers_v;
00255 
00256     num_samples = markers->num_rows;
00257     num_markers = markers->num_cols;
00258 
00259     assert(markers_v = malloc((num_markers+1)*sizeof(struct svm_node)));
00260 
00261     create_zero_vector_u32(labels_out, num_samples);
00262     labels = *labels_out;
00263 
00264     confidence_matrix = NULL;
00265     create_zero_matrix_d(&confidence_matrix, num_samples, 
00266             svm_get_nr_class(model->svm));
00267 
00268     for (s = 0; s < num_samples; s++)
00269     {
00270         for (m = 0; m < num_markers; m++)
00271         {
00272             markers_v[ m ].index = m+1;
00273             markers_v[ m ].value = markers->elts[ s ][ m ];
00274         }
00275         markers_v[ m].index = -1;
00276 
00277         predict_label_with_svm_model(&(labels->elts[ s ]),
00278                 &(confidence_matrix->elts[ s ]), markers_v, model);
00279     }
00280 
00281     get_predicted_labels_confidence(confidence_out, labels, confidence_matrix, 
00282             model);
00283 
00284     free(markers_v);
00285     free_matrix_d(confidence_matrix);
00286 
00287     return NULL;
00288 }
00289 
00290 
00299 Error* read_svm_model(SVM_model** model_out, const char* fname)
00300 {
00301     int  slen = 256;
00302     char str[slen];
00303 
00304     if (*model_out)
00305     {
00306         free_svm_model(*model_out);
00307     }
00308 
00309     assert(*model_out = malloc(sizeof(SVM_model)));
00310     (*model_out)->prob = NULL;
00311 
00312     if (!((*model_out)->svm = svm_load_model(fname)))
00313     {
00314         snprintf(str, slen, "%s: %s", fname, "Could not read model");
00315         return JWSC_EARG(str);
00316     }
00317 
00318     return NULL;
00319 }
00320 
00321 
00328 Error* write_svm_model(SVM_model* model, const char* fname)
00329 {
00330     int  slen = 256;
00331     char str[slen];
00332 
00333     if (svm_save_model(fname, model->svm) < 0)
00334     {
00335         snprintf(str, slen, "%s: %s", fname, "Could not write model");
00336         return JWSC_EARG(str);
00337     }
00338 
00339     return NULL;
00340 }
00341 
00342 
00356 Error* write_svm_model_training_data
00357 (
00358     const Vector_u32* labels,
00359     const Matrix_i32* markers,
00360     const char*       fname
00361 )
00362 {
00363     FILE*    fp;
00364     uint32_t s, num_samples;
00365     uint32_t m, num_markers;
00366     int  slen = 256;
00367     char str[slen];
00368 
00369     num_samples = markers->num_rows;
00370     num_markers = markers->num_cols;
00371 
00372     if ((fp = fopen(fname, "w")) == NULL)
00373     {
00374         snprintf(str, slen, "%s: %s", fname, strerror(errno));
00375         return JWSC_EIO(str);
00376     }
00377 
00378     for (s = 0; s < num_samples; s++)
00379     {
00380         fprintf(fp, "%-4d", labels->elts[ s ]);
00381         for (m = 0; m < num_markers; m++)
00382         {
00383             fprintf(fp, " %2d:%-4d", m+1, markers->elts[ s ][ m ]);
00384         }
00385         fprintf(fp, "\n");
00386     }
00387 
00388     if (fclose(fp) != 0)
00389     {
00390         snprintf(str, slen, "%s: %s", fname, strerror(errno));
00391         return JWSC_EIO(str);
00392     }
00393 
00394     return NULL;
00395 }
00396 
00397 
00399 void free_svm_model(SVM_model* model)
00400 {
00401     uint32_t i;
00402 
00403     if (!model)
00404         return;
00405 
00406     svm_free_and_destroy_model(&(model->svm));
00407 
00408     if (model->prob)
00409     {
00410         for (i = 0; i < model->prob->l; i++ )
00411         {
00412             free(model->prob->x[ i ]);
00413         }
00414         free(model->prob->x);
00415         free(model->prob->y);
00416         free(model->prob);
00417     }
00418 
00419     free(model);
00420 }
00421 
00422 
00424 static Error* read_svm_xml_doc
00425 (
00426     xmlDoc**    xml_doc_out,
00427     const char* xml_fname,
00428     const char* dtd_fname
00429 )
00430 {
00431     xmlParserCtxt* xml_parse_ctxt;
00432     xmlValidCtxt*  xml_valid_ctxt;
00433     xmlDtd*        xml_dtd;
00434     int            slen = 256;
00435     char           str[slen];
00436 
00437     assert(xml_parse_ctxt = xmlNewParserCtxt());
00438 
00439     if (!(*xml_doc_out = xmlCtxtReadFile(xml_parse_ctxt, xml_fname, NULL, 0)))
00440     {
00441         snprintf(str, slen, "%s: %s", xml_fname, "Could not parse file");
00442         return JWSC_EARG(str);
00443     } 
00444 
00445     xmlFreeParserCtxt(xml_parse_ctxt);
00446 
00447     if (dtd_fname)
00448     {
00449         assert(xml_valid_ctxt = xmlNewValidCtxt());
00450 
00451         if (!(xml_dtd = xmlParseDTD(NULL, (xmlChar*)dtd_fname)))
00452         {
00453             snprintf(str, slen, "%s: %s", dtd_fname, "Could not parse DTD");
00454             return JWSC_EARG(str);
00455         }
00456 
00457         if (!xmlValidateDtd(xml_valid_ctxt, *xml_doc_out, xml_dtd))
00458         {
00459             snprintf(str, slen, "%s: %s", xml_fname, "XML file not valid");
00460             return JWSC_EARG(str);
00461         }
00462 
00463         xmlFreeValidCtxt(xml_valid_ctxt);
00464         xmlFreeDtd(xml_dtd);
00465     }
00466 
00467     return NULL;
00468 }
00469 
00470 static Error* create_svm_model_tree_from_xml_node
00471 (
00472      SVM_model_tree**      tree_out, 
00473      const SVM_model_tree* parent,
00474      uint32_t              parent_label,
00475      xmlNode*              xml_node
00476 );
00477 
00479 static Error* create_svm_model_node_from_xml_node
00480 (
00481      xmlNode*        xml_node,
00482      SVM_model_node* svm_node,
00483      uint32_t        i
00484 )
00485 {
00486     uint32_t    label;
00487     uint32_t    altlabel;
00488     uint32_t    num_altlabels;
00489     uint32_t    g;
00490     uint32_t    a;
00491     const char* fname;
00492     xmlNode*    it;
00493     xmlNode*    itt;
00494     Error*      err;
00495     int         slen = 256;
00496     char        str[slen];
00497 
00498     g = 0;
00499     for (it = xml_node; it; it = it->next)
00500     {
00501         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "file"))
00502         {
00503             fname = (const char*) it->children->content;
00504             svm_node->model_fnames[ i ] = malloc((strlen(fname)+1) * 
00505                     sizeof(char));
00506             strcpy(svm_node->model_fnames[ i ], fname);
00507         }
00508         else if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "cost"))
00509         {
00510             if (sscanf((char*)it->children->content, "%lf", 
00511                         &(svm_node->cost->elts[ i ])) != 1)
00512             {
00513                 snprintf(str, slen, "%s: %s", fname, "Invalid cost");
00514                 return JWSC_EARG(str);
00515             }
00516         }
00517         else if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "gamma"))
00518         {
00519             if (sscanf((char*)it->children->content, "%lf", 
00520                         &(svm_node->gamma->elts[ i ])) != 1)
00521             {
00522                 snprintf(str, slen, "%s: %s", fname, "Invalid gamma");
00523                 return JWSC_EARG(str);
00524             }
00525         }
00526         else if (it->type == XML_ELEMENT_NODE && 
00527                 (XMLStrEqual(it->name, "group-one") ||
00528                  XMLStrEqual(it->name, "group-all")))
00529         {
00530             a = 0;
00531             num_altlabels = 0;
00532             for (itt = it->children; itt; itt = itt->next)
00533             {
00534                 if (itt->type == XML_ELEMENT_NODE &&
00535                         XMLStrEqual(itt->name, "altlabel"))
00536                 {
00537                     num_altlabels++;
00538                 }
00539             }
00540             if (num_altlabels)
00541             {
00542                 create_vector_u32(&(svm_node->altlabels[ g ][ i ]), 
00543                         num_altlabels);
00544             }
00545 
00546             for (itt = it->children; itt; itt = itt->next)
00547             {
00548                 if (itt->type == XML_ELEMENT_NODE &&
00549                         XMLStrEqual(itt->name, "label"))
00550                 {
00551                     if ((err = lookup_haplo_group_index_from_label(&label, 
00552                                     (const char*) itt->children->content)))
00553                     {
00554                         snprintf(str, slen, "%s: %s", fname, err->msg);
00555                         return JWSC_EARG(str);
00556                     }
00557                     svm_node->labels[ g ]->elts[ i ] = label;
00558                 }
00559                 else if (itt->type == XML_ELEMENT_NODE &&
00560                         XMLStrEqual(itt->name, "altlabel"))
00561                 {
00562                     if ((err = lookup_haplo_group_index_from_label(&altlabel, 
00563                                     (const char*) itt->children->content)))
00564                     {
00565                         snprintf(str, slen, "%s: %s", fname, err->msg);
00566                         return JWSC_EARG(str);
00567                     }
00568                     svm_node->altlabels[ g ][ i ]->elts[ a++ ] = altlabel;
00569                 }
00570                 else if (itt->type == XML_ELEMENT_NODE && 
00571                          (XMLStrEqual(itt->name, "binary-model") ||
00572                           XMLStrEqual(itt->name, "one-vs-all-model") ||
00573                           XMLStrEqual(itt->name, "one-vs-one-model")))
00574                 {
00575                     if ((err = create_svm_model_tree_from_xml_node(
00576                                     &(svm_node->subtrees[ g ][ i ]), svm_node,
00577                                     label, itt)))
00578                     {
00579                         snprintf(str, slen, "%s: %s", fname, err->msg);
00580                         return JWSC_EARG(str);
00581                     }
00582                     break;
00583                 }
00584             }
00585             g++;
00586         }
00587     }
00588 
00589     return NULL;
00590 }
00591 
00592 
00599 static Error* create_svm_model_tree_from_xml_node
00600 (
00601      SVM_model_tree**      tree_out, 
00602      const SVM_model_tree* parent,
00603      uint32_t              parent_label,
00604      xmlNode*              xml_node
00605 )
00606 {
00607     SVM_model_tree* tree;
00608     uint32_t i;
00609     xmlNode* it;
00610     Error*   err;
00611 
00612     if (*tree_out)
00613     {
00614         free_svm_model_tree(*tree_out);
00615     }
00616 
00617     assert(*tree_out = malloc(sizeof(SVM_model_tree)));
00618     tree = *tree_out;
00619 
00620     tree->parent       = parent;
00621     tree->parent_label = parent_label;
00622     tree->subtrees[0]  = NULL;
00623     tree->subtrees[1]  = NULL;
00624     tree->labels[0]    = NULL;
00625     tree->labels[1]    = NULL;
00626     tree->altlabels[0] = NULL;
00627     tree->altlabels[1] = NULL;
00628     tree->cost         = NULL;
00629     tree->gamma        = NULL;
00630     tree->num_models   = 0;
00631     tree->models       = NULL;
00632     tree->model_fnames = NULL;
00633 
00634     assert(XMLStrEqual(xml_node->name, "binary-model") ||
00635            XMLStrEqual(xml_node->name, "one-vs-all-model") ||
00636            XMLStrEqual(xml_node->name, "one-vs-one-model"));
00637 
00638     for (it = xml_node; it; it = it->next)
00639     {
00640         if (it->type == XML_ELEMENT_NODE && 
00641                 (XMLStrEqual(it->name, "binary-model") ||
00642                  XMLStrEqual(it->name, "one-vs-all-model") ||
00643                  XMLStrEqual(it->name, "one-vs-one-model")))
00644         {
00645             tree->num_models++;
00646         }
00647     }
00648 
00649     assert(tree->subtrees[0] = calloc(tree->num_models, sizeof(void*)));
00650     assert(tree->subtrees[1] = calloc(tree->num_models, sizeof(void*)));
00651     assert(tree->altlabels[0] = calloc(tree->num_models, sizeof(void*)));
00652     assert(tree->altlabels[1] = calloc(tree->num_models, sizeof(void*)));
00653     assert(tree->model_fnames = calloc(tree->num_models, sizeof(void*)));
00654     assert(tree->models = calloc(tree->num_models, sizeof(void*)));
00655 
00656     create_vector_u32(&(tree->labels[0]), tree->num_models);
00657     create_vector_u32(&(tree->labels[1]), tree->num_models);
00658     create_vector_d(&(tree->cost), tree->num_models);
00659     create_vector_d(&(tree->gamma), tree->num_models);
00660 
00661     i = 0;
00662     for (it = xml_node; it; it = it->next)
00663     {
00664         if (it->type == XML_ELEMENT_NODE && 
00665                 (XMLStrEqual(it->name, "binary-model") ||
00666                  XMLStrEqual(it->name, "one-vs-all-model") ||
00667                  XMLStrEqual(it->name, "one-vs-one-model")))
00668         {
00669             if ((err = create_svm_model_node_from_xml_node(it->children, tree, 
00670                         i++)))
00671             {
00672                 return err;
00673             }
00674         }
00675     }
00676     assert(i == tree->num_models);
00677 
00678     return NULL;
00679 }
00680 
00681 
00686 static Error* create_model_training_data
00687 (
00688     Vector_u32**       train_labels_out,
00689     Matrix_i32**       train_markers_out,
00690     const Vector_u32*  data_labels,
00691     const Matrix_i32*  data_markers,
00692     uint32_t*          model_labels,
00693     const Vector_u32** model_altlabels
00694 )
00695 {
00696     uint8_t     b;
00697     uint32_t    i, j, k;
00698     uint32_t    n;
00699     uint32_t    m[2] = {0};
00700     const char* label_str;
00701     Vector_u32* train_labels = NULL;
00702     Matrix_i32* train_markers = NULL;
00703     int         slen = 256;
00704     char        str[slen];
00705 
00706     copy_vector_u32(&train_labels, data_labels);
00707     copy_matrix_i32(&train_markers, data_markers);
00708 
00709     n = 0;
00710     for (i = 0; i < data_labels->num_elts; i++)
00711     {
00712         for (j = 0; j < 2; j++)
00713         {
00714             if (is_ancestor(data_labels->elts[ i ], model_labels[ j ]))
00715             {
00716                 train_labels->elts[ n ] = model_labels[ j ];
00717                 copy_matrix_block_into_matrix_i32(train_markers, n, 0,
00718                         data_markers, i, 0, 1, data_markers->num_cols);
00719                 n++;
00720                 m[j]++;
00721                 break;
00722             }
00723             else if (model_altlabels[ j ])
00724             {
00725                 for (k = 0; k < model_altlabels[ j ]->num_elts; k++)
00726                 {
00727                     if ((b = is_ancestor(data_labels->elts[ i ], 
00728                                     model_altlabels[ j ]->elts[ k ])))
00729                     {
00730                         train_labels->elts[ n ] = model_labels[ j ];
00731                         copy_matrix_block_into_matrix_i32(train_markers, n, 0,
00732                                 data_markers, i, 0, 1, data_markers->num_cols);
00733                         n++;
00734                         m[j]++;
00735                         break;
00736                     }
00737                 }
00738                 if (b)
00739                 {
00740                     break;
00741                 }
00742             }
00743         }
00744     }
00745 
00746     assert(n == (m[0] + m[1]));
00747 
00748     if (!(m[j=0]) || !(m[j=1]))
00749     {
00750         lookup_haplo_group_label_from_index(&label_str, model_labels[j]);
00751         snprintf(str, slen, "No data for model label %s", label_str);
00752         return JWSC_EARG(str);
00753     }
00754 
00755     copy_vector_section_u32(train_labels_out, train_labels, 0, n);
00756     copy_matrix_block_i32(train_markers_out, train_markers, 0, 0, n, 
00757             train_markers->num_cols);
00758 
00759     free_vector_u32(train_labels);
00760     free_matrix_i32(train_markers);
00761 
00762     return NULL;
00763 }
00764 
00765 
00767 static Error* train_svm_model_node
00768 (
00769     SVM_model_node*   node,
00770     const Vector_u32* labels, 
00771     const Matrix_i32* markers
00772 )
00773 {
00774     uint32_t          i, j;
00775     uint32_t          label[2];
00776     const Vector_u32* altlabel[2];
00777     Vector_u32*       node_labels  = NULL;
00778     Matrix_i32*       node_markers = NULL;
00779     Error*            err;
00780     int               slen = 256;
00781     char              str[slen];
00782 
00783     for (i = 0; i < node->num_models; i++)
00784     {
00785         label[0] = node->labels[0]->elts[ i ];
00786         label[1] = node->labels[1]->elts[ i ];
00787 
00788         altlabel[0] = node->altlabels[0][ i ];
00789         altlabel[1] = node->altlabels[1][ i ];
00790 
00791         if ((err = create_model_training_data(&node_labels, &node_markers, 
00792                         labels, markers, label, altlabel)))
00793         {
00794             snprintf(str, slen, "%s: %s", node->model_fnames[i], err->msg);
00795             return JWSC_EARG(str);
00796         }
00797 
00798         train_svm_model(&(node->models[ i ]), node_labels, node_markers,
00799                 node->cost->elts[ i ], node->gamma->elts[ i ]);
00800 
00801         for (j = 0; j < 2; j++)
00802         {
00803             if (node->subtrees[ j ][ i ])
00804             {
00805                 if ((err = train_svm_model_node(node->subtrees[ j ][ i ],
00806                                 labels, markers)))
00807                 {
00808                     return err;
00809                 }
00810             }
00811         }
00812     }
00813 
00814     free_vector_u32(node_labels);
00815     free_matrix_i32(node_markers);
00816 
00817     return NULL;
00818 }
00819 
00820 
00828 Error* train_svm_model_tree
00829 (
00830     SVM_model_tree**  tree_out,
00831     const Vector_u32* labels, 
00832     const Matrix_i32* markers,
00833     const char*       tree_xml_fname,
00834     const char*       tree_dtd_fname
00835 )
00836 {
00837     xmlDoc*  xml_doc;
00838     xmlNode* xml_node = NULL;
00839     xmlNode* it;
00840     Error*   err;
00841     int      slen = 256;
00842     char     str[slen];
00843 
00844     assert(tree_xml_fname);
00845 
00846     if ((err = read_svm_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname)))
00847     {
00848         return err;
00849     }
00850 
00851     for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next)
00852     {
00853         if (it->type == XML_ELEMENT_NODE && 
00854                 (XMLStrEqual(it->name, "binary-model") ||
00855                  XMLStrEqual(it->name, "one-vs-all-model") ||
00856                  XMLStrEqual(it->name, "one-vs-one-model")))
00857         {
00858             xml_node = it;
00859             break;
00860         }
00861     }
00862 
00863     assert(xml_node);
00864 
00865     if ((err = create_svm_model_tree_from_xml_node(tree_out, NULL, 0, 
00866                     xml_node)))
00867     {
00868         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
00869         return JWSC_EARG(str);
00870     }
00871 
00872     if ((err = train_svm_model_node(*tree_out, labels, markers)))
00873     {
00874         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
00875         return JWSC_EARG(str);
00876     }
00877 
00878     return NULL;
00879 }
00880 
00881 
00883 static double get_predicted_label_confidence
00884 (
00885     uint32_t         label,
00886     double*          confs,
00887     const SVM_model* model
00888 )   
00889 {   
00890     uint32_t l, num_labels;
00891     int*     labels;
00892     double   conf;
00893 
00894     num_labels = (uint32_t)svm_get_nr_class(model->svm);
00895     assert(labels = malloc(num_labels*sizeof(int)));
00896     svm_get_labels(model->svm, labels);
00897 
00898     for (l = 0; l < num_labels; l++)
00899     {
00900         if (labels[ l ] == label)
00901         {
00902            conf = confs[ l ];
00903         }
00904     }
00905 
00906     free(labels);
00907 
00908     return conf;
00909 }
00910 
00911 
00913 static uint32_t get_best_one_against_all_model
00914 (
00915     const SVM_model_tree* tree,
00916     Vector_u32*           labels,
00917     Vector_d*             confs
00918 )
00919 {
00920     uint32_t    label, other_label;
00921     uint32_t    num_models;
00922     uint32_t    m;
00923     uint32_t    num_labels;
00924     int32_t     best_model           = -1;
00925     double      max_conf             = 0;
00926     double      min_conf             = 1;
00927     double      label_conf;
00928     double      normalizing_c        = 0;
00929 
00930     Vector_d* label_conf_sum = NULL;
00931 
00932     num_models = tree->num_models;
00933     num_labels = get_num_haplo_groups();
00934 
00935     assert(num_models > 1);
00936 
00937     create_init_vector_d(&label_conf_sum, num_labels, 0);
00938 
00939     /* Use the sum the confidence values for each predicted label to decide 
00940      * which is best. If there is a tie, the first occuring label is choosen. */
00941     for (m = 0; m < num_models; m++)
00942     {
00943         label = labels->elts[ m ];
00944         label_conf = confs->elts[ m ];
00945 
00946         if (tree->labels[0]->elts[ m ] == label)
00947         {
00948             label_conf_sum->elts[ label ] += label_conf;
00949             other_label = tree->labels[1]->elts[ m ];
00950         }
00951         else
00952         {
00953             label_conf_sum->elts[ tree->labels[0]->elts[ m ] ] += 1-label_conf;
00954             other_label = tree->labels[0]->elts[ m ];
00955         }
00956     }
00957 
00958     for (m = 0; m < num_models; m++)
00959     {
00960         label = labels->elts[ m ];
00961         label_conf = label_conf_sum->elts[ label ];
00962 
00963         if ((label == tree->labels[0]->elts[ m ]) && (max_conf < label_conf))
00964         {
00965             max_conf = label_conf;
00966             best_model = m;
00967         }
00968     }
00969 
00970     /* It's possible that nothing was choosen because the 'other_label' was  
00971      * selected in all groups. In this case, choose the predicted label whose
00972      * model had the least confidence in its 'other_label' prediction. */
00973     if (best_model < 0)
00974     {
00975         for (m = 0; m < num_models; m++)
00976         {
00977             label = labels->elts[ m ];
00978             assert(label == tree->labels[1]->elts[ m ]);
00979             label_conf = confs->elts[ m ];
00980 
00981             if (min_conf > label_conf)
00982             {
00983                 min_conf = label_conf;
00984                 best_model = m;
00985             }
00986         }
00987 
00988         /* Hack the predicted_labels and confidence data structures to reflect
00989          * that 'other_label' was not choosen. */
00990         labels->elts[ best_model ] = tree->labels[0]->elts[ best_model ];
00991 
00992         label_conf = confs->elts[ best_model ];
00993         confs->elts[ best_model ] = 1.0 - label_conf;
00994     }
00995 
00996     /* Need to normalize the confidence value for the best label prediction so 
00997      * that it is a probability. At this level in the tree, one of the labels 
00998      * will be chosen, so the confidence value for these labels needs to sum
00999      * to one. */
01000     for (label = 0; label < num_labels; label++)
01001     {
01002         normalizing_c += label_conf_sum->elts[ label ];
01003     }
01004     assert(normalizing_c > 1.0e-16);
01005     confs->elts[ best_model ] /= normalizing_c;
01006 
01007     free_vector_d(label_conf_sum);
01008 
01009     return best_model;
01010 }
01011 
01012 
01014 static uint32_t get_best_one_against_one_model
01015 (
01016     const SVM_model_tree* tree,
01017     Vector_u32*           labels,
01018     Vector_d*             confs
01019 )
01020 {
01021     uint32_t    m, num_models;
01022     uint32_t    l, num_labels;
01023     uint32_t    best_model;
01024     uint32_t    max;
01025     uint32_t    the_winner;
01026     uint32_t   num_winners;
01027     double max_conf;
01028 
01029     Vector_u32* winners = NULL;
01030 
01031     num_models = tree->num_models;
01032     num_labels = get_num_haplo_groups();
01033 
01034     assert(num_models > 1);
01035 
01036     create_init_vector_u32(&winners, num_labels, 0);
01037 
01038     for (m = 0; m < num_models; m++)
01039     {
01040         winners->elts[ labels->elts[ m ] ]++;
01041     }
01042 
01043     max = 0;
01044     the_winner = 0;
01045     num_winners = 0;
01046 
01047     for (l = 0; l < num_labels; l++)
01048     {
01049         if (winners->elts[ l ] > max)
01050         {
01051             the_winner = l;
01052             max = winners->elts[ l ];
01053         }
01054     }
01055     for (l = 0; l < num_labels; l++)
01056     {
01057         if (max == winners->elts[ l ])
01058         {
01059             num_winners++;
01060         }
01061     }
01062 
01063     assert(max > 0);
01064 
01065     max_conf = 0;
01066 
01067     for (m = 0; m < num_models; m++)
01068     {
01069         if (labels->elts[ m ] == the_winner && confs->elts[ m ] > max_conf)
01070         {
01071             best_model = m;
01072         }
01073     }
01074 
01075     free_vector_u32(winners);
01076 
01077     return best_model;
01078 }
01079 
01080 
01082 static Error* recursively_predict_label_in_model_tree
01083 (
01084     uint32_t*                label_out,
01085     double*                  confidence_out,
01086     const SVM_model_tree*    tree,
01087     const struct svm_node*   markers
01088 )
01089 {
01090     uint32_t i;
01091     uint32_t m, num_models;
01092     uint32_t subtree_label;
01093     Error*   err;
01094 
01095     Vector_u32* labels = NULL;
01096     Vector_d*   confs  = NULL;
01097     Vector_d*   v      = NULL;
01098 
01099     num_models = tree->num_models;
01100 
01101     create_vector_u32(&labels, num_models);
01102     create_vector_d(&confs, num_models);
01103 
01104     for (m = 0; m < num_models; m++)
01105     {
01106         create_zero_vector_d(&v, svm_get_nr_class(tree->models[ m ]->svm));
01107 
01108         if ((err = predict_label_with_svm_model(&(labels->elts[ m ]),
01109                         &(v->elts), markers, tree->models[ m ])))
01110         {
01111             return err;
01112         }
01113 
01114         confs->elts[ m ] = get_predicted_label_confidence(labels->elts[ m ], 
01115                 v->elts, tree->models[ m ]);
01116     }
01117     free_vector_d(v);
01118 
01119     // If there are multiple models in the node, get the prediction from
01120     // the best one.
01121     m = (num_models > 1) ? get_best_one_against_all_model(tree, labels, confs) 
01122         : 0;
01123 
01124     *label_out = labels->elts[ m ];
01125     *confidence_out *= confs->elts[ m ];
01126 
01127     free_vector_u32(labels);
01128     free_vector_d(confs);
01129 
01130     // If the best model has a subtree, recursively predict in the subtree.
01131     for (m = 0; m < num_models; m++)
01132     {
01133         for (i = 0; i < 2; i++)
01134         {
01135             if (!(tree->subtrees[ i ][ m ]))
01136                 continue;
01137 
01138             subtree_label = tree->subtrees[ i ][ m ]->parent_label;
01139 
01140             if (subtree_label == *label_out)
01141             {
01142                 if ((err = recursively_predict_label_in_model_tree(label_out,
01143                                 confidence_out, tree->subtrees[ i ][ m ],
01144                                 markers)))
01145                 {
01146                     return err;
01147                 }
01148             }
01149         }
01150     }
01151 
01152     return NULL;
01153 }
01154 
01155 
01169 Error* predict_labels_with_svm_model_tree
01170 (
01171     Vector_u32**          labels_out,
01172     Vector_d**            confidence_out,
01173     const Matrix_i32*     markers,
01174     const SVM_model_tree* tree
01175 )
01176 {
01177     uint32_t s, num_samples;
01178     uint32_t m, num_markers;
01179 
01180     Vector_u32*      labels;
01181     Vector_d*        confidence;
01182     struct svm_node* markers_v;
01183     Error*           e;
01184 
01185     num_samples = markers->num_rows;
01186     num_markers = markers->num_cols;
01187 
01188     create_vector_u32(labels_out, num_samples);
01189     labels = *labels_out;
01190 
01191     create_init_vector_d(confidence_out, num_samples, 1.0);
01192     confidence = *confidence_out;
01193 
01194     markers_v = NULL;
01195     assert(markers_v = malloc((num_markers+1)*sizeof(struct svm_node)));
01196 
01197     for (s = 0; s < num_samples; s++)
01198     {
01199         for (m  = 0; m  < num_markers; m ++)
01200         {
01201             markers_v[ m  ].index = m +1;
01202             markers_v[ m  ].value = markers->elts[ s ][ m  ];
01203         }
01204         markers_v[ m ].index = -1;
01205 
01206         if ((e = recursively_predict_label_in_model_tree(
01207                         &(labels->elts[ s ]), 
01208                         &(confidence->elts[ s ]), tree, markers_v)) 
01209                 != NULL)
01210         {
01211             return e;
01212         }
01213     }
01214 
01215     free(markers_v);
01216 
01217     return NULL;
01218 }
01219 
01220 
01222 static Error* read_svm_model_node
01223 (
01224      SVM_model_node* node, 
01225      const char*     model_dirname
01226 )
01227 {
01228     uint32_t i, j;
01229     char     buf[1024] = {0};
01230     Error*   err;
01231 
01232     for (i = 0; i < node->num_models; i++)
01233     {
01234         snprintf(buf, 1024, "%s/%s", model_dirname, node->model_fnames[ i ]);
01235         if ((err = read_svm_model(&(node->models[ i ]), buf)))
01236         {
01237             return err;
01238         }
01239 
01240         for (j = 0; j < 2; j++)
01241         {
01242             if (node->subtrees[ j ][ i ])
01243             {
01244                 if ((err = read_svm_model_node(node->subtrees[ j ][ i ],
01245                                 model_dirname)))
01246                 {
01247                     return err;
01248                 }
01249             }
01250         }
01251     }
01252 
01253     return NULL;
01254 }
01255 
01256 
01263 Error* read_svm_model_tree
01264 (
01265     SVM_model_tree** tree_out,
01266     const char*      tree_xml_fname,
01267     const char*      tree_dtd_fname,
01268     const char*      model_dirname
01269 )
01270 {
01271     xmlDoc*  xml_doc;
01272     xmlNode* xml_node = NULL;
01273     xmlNode* it;
01274     Error*   err;
01275     int      slen = 256;
01276     char     str[slen];
01277 
01278     assert(tree_xml_fname && model_dirname);
01279 
01280     if ((err = read_svm_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname)))
01281     {
01282         return err;
01283     }
01284 
01285     for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next)
01286     {
01287         if (it->type == XML_ELEMENT_NODE && 
01288                 (XMLStrEqual(it->name, "binary-model") ||
01289                  XMLStrEqual(it->name, "one-vs-all-model") ||
01290                  XMLStrEqual(it->name, "one-vs-one-model")))
01291         {
01292             xml_node = it;
01293             break;
01294         }
01295     }
01296 
01297     assert(xml_node);
01298 
01299     if ((err = create_svm_model_tree_from_xml_node(tree_out, NULL, 0,
01300                     xml_node)))
01301     {
01302         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
01303         return JWSC_EARG(str);
01304     }
01305 
01306     if ((err = read_svm_model_node(*tree_out, model_dirname)))
01307     {
01308         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
01309         return JWSC_EARG(str);
01310     }
01311 
01312     return NULL;
01313 }
01314 
01315 
01320 Error* write_svm_model_tree
01321 (
01322     const SVM_model_tree* tree,
01323     const char*           model_dirname
01324 )
01325 {
01326     uint32_t i, j;
01327     char     buf[1024] = {0};
01328     Error*   err;
01329     int      slen = 256;
01330     char     str[slen];
01331 
01332     for (i = 0; i < tree->num_models; i++)
01333     {
01334         snprintf(buf, 1024, "%s/%s", model_dirname, tree->model_fnames[ i ]);
01335         if ((err = write_svm_model(tree->models[ i ], buf)))
01336         {
01337             return err;
01338         }
01339 
01340         for (j = 0; j < 2; j++)
01341         {
01342             if (tree->subtrees[ j ][ i ])
01343             {
01344                 if ((err = write_svm_model_tree(tree->subtrees[ j ][ i ], 
01345                                 model_dirname)))
01346                 {
01347                     snprintf(str, slen, "%s: %s", tree->model_fnames[ i ], 
01348                             err->msg);
01349                     return JWSC_EARG(str);
01350                 }
01351             }
01352         }
01353     }
01354 
01355     return NULL;
01356 }
01357 
01358 
01363 static Error* write_svm_model_node_training_data
01364 (
01365     SVM_model_node*   node,
01366     const Vector_u32* labels, 
01367     const Matrix_i32* markers,
01368     const char*       data_dirname
01369 )
01370 {
01371     uint32_t          i, j;
01372     uint32_t          label[2];
01373     char              buf[1024] = {0};
01374     const Vector_u32* altlabel[2];
01375     Vector_u32*       node_labels  = NULL;
01376     Matrix_i32*       node_markers = NULL;
01377     Error*            err;
01378     int               slen = 256;
01379     char              str[slen];
01380 
01381     for (i = 0; i < node->num_models; i++)
01382     {
01383         label[0] = node->labels[0]->elts[ i ];
01384         label[1] = node->labels[1]->elts[ i ];
01385 
01386         altlabel[0] = node->altlabels[0][ i ];
01387         altlabel[1] = node->altlabels[1][ i ];
01388 
01389         if ((err = create_model_training_data(&node_labels, &node_markers, 
01390                         labels, markers, label, altlabel)))
01391         {
01392             snprintf(str, slen, "%s: %s", node->model_fnames[i], err->msg);
01393             return JWSC_EARG(str);
01394         }
01395 
01396         snprintf(buf, 1024, "%s/%s", data_dirname, node->model_fnames[ i ]);
01397 
01398         if ((err = write_svm_model_training_data(node_labels, node_markers,
01399                         buf)))
01400         {
01401             return err;
01402         }
01403 
01404         for (j = 0; j < 2; j++)
01405         {
01406             if (node->subtrees[ j ][ i ])
01407             {
01408                 if ((err = write_svm_model_node_training_data(
01409                                 node->subtrees[ j ][ i ], labels, markers,
01410                                 data_dirname)))
01411                 {
01412                     return err;
01413                 }
01414             }
01415         }
01416     }
01417 
01418     free_vector_u32(node_labels);
01419     free_matrix_i32(node_markers);
01420 
01421     return NULL;
01422 }
01423 
01424 
01435 Error* write_svm_model_tree_training_data
01436 (
01437     const Vector_u32* labels, 
01438     const Matrix_i32* markers,
01439     const char*       tree_xml_fname,
01440     const char*       tree_dtd_fname,
01441     const char*       data_dirname
01442 )
01443 {
01444     SVM_model_tree* tree = NULL;
01445     xmlDoc*  xml_doc;
01446     xmlNode* xml_node = NULL;
01447     xmlNode* it;
01448     Error*   err;
01449     int      slen = 256;
01450     char     str[slen];
01451 
01452     assert(tree_xml_fname);
01453 
01454     if ((err = read_svm_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname)))
01455     {
01456         return err;
01457     }
01458 
01459     for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next)
01460     {
01461         if (it->type == XML_ELEMENT_NODE && 
01462                 (XMLStrEqual(it->name, "binary-model") ||
01463                  XMLStrEqual(it->name, "one-vs-one-model") ||
01464                  XMLStrEqual(it->name, "one-vs-all-model")))
01465         {
01466             xml_node = it;
01467             break;
01468         }
01469     }
01470 
01471     assert(xml_node);
01472 
01473     if ((err = create_svm_model_tree_from_xml_node(&tree, NULL, 0, xml_node)))
01474     {
01475         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
01476         return JWSC_EARG(str);
01477     }
01478 
01479     if ((err = write_svm_model_node_training_data(tree, labels, markers,
01480                     data_dirname)))
01481     {
01482         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
01483         return JWSC_EARG(str);
01484     }
01485 
01486     free_svm_model_tree(tree);
01487 
01488     return NULL;
01489 }
01490 
01491 
01493 void free_svm_model_tree(SVM_model_tree* tree)
01494 {
01495     uint32_t i;
01496 
01497     if (!tree)
01498         return;
01499 
01500     for (i = 0; i < tree->num_models; i++) 
01501     {
01502         free_svm_model_tree(tree->subtrees[0][ i ]);
01503         free_svm_model_tree(tree->subtrees[1][ i ]);
01504         free_vector_u32(tree->altlabels[0][ i ]);
01505         free_vector_u32(tree->altlabels[1][ i ]);
01506         free_svm_model(tree->models[ i ]);
01507         free(tree->model_fnames[ i ]);
01508     }
01509 
01510     free(tree->subtrees[0]);
01511     free(tree->subtrees[1]);
01512     free_vector_u32(tree->labels[0]);
01513     free_vector_u32(tree->labels[1]);
01514     free(tree->altlabels[0]);
01515     free(tree->altlabels[1]);
01516     free_vector_d(tree->cost);
01517     free_vector_d(tree->gamma);
01518     free(tree->models);
01519     free(tree->model_fnames);
01520     free(tree);
01521 }