Haplo Prediction
predict haplogroups
weka.c
Go to the documentation of this file.
00001 /*
00002  * This work is licensed under a Creative Commons 
00003  * Attribution-Noncommercial-Share Alike 3.0 United States License.
00004  * 
00005  *    http://creativecommons.org/licenses/by-nc-sa/3.0/us/
00006  * 
00007  * You are free:
00008  * 
00009  *    to Share - to copy, distribute, display, and perform the work
00010  *    to Remix - to make derivative works
00011  * 
00012  * Under the following conditions:
00013  * 
00014  *    Attribution. You must attribute the work in the manner specified by the
00015  *    author or licensor (but not in any way that suggests that they endorse you
00016  *    or your use of the work).
00017  * 
00018  *    Noncommercial. You may not use this work for commercial purposes.
00019  * 
00020  *    Share Alike. If you alter, transform, or build upon this work, you may
00021  *    distribute the resulting work only under the same or similar license to
00022  *    this one.
00023  * 
00024  * For any reuse or distribution, you must make clear to others the license
00025  * terms of this work. The best way to do this is by including this header.
00026  * 
00027  * Any of the above conditions can be waived if you get permission from the
00028  * copyright holder.
00029  * 
00030  * Apart from the remix rights granted under this license, nothing in this
00031  * license impairs or restricts the author's moral rights.
00032  */
00033 
00034 
00046 #include <config.h>
00047 
00048 #include <stdlib.h>
00049 #include <stdio.h>
00050 #include <assert.h>
00051 #include <inttypes.h>
00052 #include <unistd.h>
00053 #include <string.h>
00054 #include <errno.h>
00055 
00056 #include <libxml/tree.h>
00057 #include <libxml/parser.h>
00058 #include <libxml/valid.h>
00059 
00060 #ifdef HAPLO_HAVE_DMALLOC
00061 #include <dmalloc.h>
00062 #endif
00063 
00064 #include <jwsc/base/error.h>
00065 #include <jwsc/base/file_io.h>
00066 #include <jwsc/vector/vector.h>
00067 #include <jwsc/matrix/matrix.h>
00068 
00069 #include "xml.h"
00070 #include "haplo_groups.h"
00071 #include "weka.h"
00072 
00073 
00075 static Error* read_weka_model_labels
00076 (
00077     char**    labels_str_out, 
00078     const char* labels_fname
00079 )
00080 {
00081     static char buf[4096] = {0};
00082     char* buf_ptr;
00083     char  c;
00084     FILE* fp;
00085     int   slen = 1024;
00086     char  str[slen];
00087 
00088     if ((fp = fopen(labels_fname, "r")) == NULL)
00089     {
00090         snprintf(str, slen, "%s: %s", labels_fname, strerror(errno));
00091         return JWSC_EIO(str);
00092     }
00093 
00094     if (!skip_fp_spaces_and_comments(fp))
00095     {
00096         snprintf(str, slen, "%s: %s", labels_fname, 
00097                 "Labels file not formatted properly");
00098         return JWSC_EARG(str);
00099     }
00100 
00101     buf_ptr = buf;
00102     while ((c = fgetc(fp)) != EOF && c != '\n' && buf_ptr != (buf+4096))
00103     {
00104         *buf_ptr = c;
00105         buf_ptr++;
00106     }
00107     if (buf_ptr != (buf+4096))
00108     {
00109         *buf_ptr = 0;
00110     }
00111 
00112     if (*labels_str_out)
00113     {
00114         free(*labels_str_out);
00115     }
00116     *labels_str_out = malloc((strlen(buf)+1)*sizeof(char));
00117     strcpy(*labels_str_out, buf);
00118 
00119     if (fclose(fp) != 0)
00120     {
00121         snprintf(str, slen, "%s: %s", labels_fname, strerror(errno));
00122         return JWSC_EIO(str);
00123     }
00124 
00125     return NULL;
00126 }
00127 
00128 
00140 Error* train_weka_j48_model
00141 (
00142     const Vector_u32* labels, 
00143     const Matrix_i32* markers,
00144     const char*       labels_fname,
00145     const char*       model_fname,
00146     const char*       weka_jar_fname
00147 )
00148 {
00149     static char script[4096] = {0};
00150     static char tmp_data_csv_fname[256] = {0};
00151     const char* label_str;
00152 
00153     FILE*    fp;
00154     uint32_t pid;
00155     uint32_t s, num_samples;
00156     uint32_t m, num_markers;
00157     int      slen = 1024;
00158     char     str[slen];
00159 
00160     pid = getpid();
00161 
00162     snprintf(tmp_data_csv_fname, 255, "/tmp/.weka_j48_train_data_csv-%u", pid);
00163     if (!(fp = fopen(tmp_data_csv_fname, "w")))
00164     {
00165         snprintf(str, slen, "%s: %s", tmp_data_csv_fname, strerror(errno));
00166         return JWSC_EIO(str);
00167     }
00168 
00169     num_samples = markers->num_rows;
00170     num_markers = markers->num_cols;
00171 
00172     for (s = 0; s < num_samples; s++)
00173     {
00174         assert(lookup_haplo_group_label_from_index(&label_str, 
00175                     labels->elts[ s ]) == NULL);
00176         fprintf(fp, "%s", label_str);
00177         for (m = 0; m < num_markers; m++)
00178         {
00179             fprintf(fp, ",%u", markers->elts[ s ][ m ]);
00180         }
00181         fprintf(fp, "\n");
00182     }
00183 
00184     if (fclose(fp) != 0)
00185     {
00186         snprintf(str, slen, "%s: %s", tmp_data_csv_fname, strerror(errno));
00187         return JWSC_EIO(str);
00188     }
00189 
00190     snprintf(script, 4096, 
00191             "WEKA_JAR=\"%s\"; \
00192              JAVA=\"%s\"; \
00193              EGREP=\"%s\"; \
00194              AWK=\"%s\"; \
00195              SED=\"%s\"; \
00196              TMP_TRAIN_DATA_ARFF=\"/tmp/.weka_j48_train_data_arff-%u\"; \
00197              TMP_TRAIN_DATA_CSV=\"%s\"; \
00198              JVM_OPTS=\"-Xmx1024m -cp $WEKA_JAR\"; \
00199              CSV_CONVERTER=\"weka.core.converters.CSVLoader\"; \
00200              J48=\"weka.classifiers.trees.J48\"; \
00201              J48_OPTS=\"-v -t $TMP_TRAIN_DATA_ARFF -x 2 -c 1 -d %s\"; \
00202              $JAVA $JVM_OPTS $CSV_CONVERTER $TMP_TRAIN_DATA_CSV \
00203                    > $TMP_TRAIN_DATA_ARFF; \
00204              $EGREP '@attribute' $TMP_TRAIN_DATA_ARFF \
00205                    | $SED '2,$d' \
00206                    | $AWK '{for (i=3;i<=NF;i++) printf \"%%s \", $i; printf \"\\n\"}' \
00207                    > %s; \
00208              $JAVA $JVM_OPTS $J48 $J48_OPTS >> /dev/null; \
00209              rm -f $TMP_TRAIN_DATA_ARFF; \
00210              rm -f $TMP_TRAIN_DATA_CSV", 
00211              weka_jar_fname, HAPLO_JAVA, HAPLO_EGREP, HAPLO_AWK, HAPLO_SED,
00212              pid, tmp_data_csv_fname, model_fname, labels_fname);
00213 
00214     system(script);
00215 
00216     return NULL;
00217 }
00218 
00219 
00231 Error* train_weka_part_model
00232 (
00233     const Vector_u32* labels, 
00234     const Matrix_i32* markers,
00235     const char*       labels_fname,
00236     const char*       model_fname,
00237     const char*       weka_jar_fname
00238 )
00239 {
00240     static char script[4096] = {0};
00241     static char tmp_data_csv_fname[256] = {0};
00242     const char* label_str;
00243 
00244     FILE*    fp;
00245     uint32_t pid;
00246     uint32_t s, num_samples;
00247     uint32_t m, num_markers;
00248     int      slen = 1024;
00249     char     str[slen];
00250 
00251     pid = getpid();
00252 
00253     snprintf(tmp_data_csv_fname, 255, "/tmp/.weka_part_train_data_csv-%u", pid);
00254     if (!(fp = fopen(tmp_data_csv_fname, "w")))
00255     {
00256         snprintf(str, slen, "%s: %s", tmp_data_csv_fname, 
00257                 strerror(errno));
00258         return JWSC_EIO(str);
00259     }
00260 
00261     num_samples = markers->num_rows;
00262     num_markers = markers->num_cols;
00263 
00264     for (s = 0; s < num_samples; s++)
00265     {
00266         assert(lookup_haplo_group_label_from_index(&label_str, 
00267                     labels->elts[ s ]) == NULL);
00268         fprintf(fp, "%s", label_str);
00269         for (m = 0; m < num_markers; m++)
00270         {
00271             fprintf(fp, ",%u", markers->elts[ s ][ m ]);
00272         }
00273         fprintf(fp, "\n");
00274     }
00275 
00276     if (fclose(fp) != 0)
00277     {
00278         snprintf(str, slen, "%s: %s", tmp_data_csv_fname, strerror(errno));
00279         return JWSC_EIO(str);
00280     }
00281 
00282     snprintf(script, 4096, 
00283             "WEKA_JAR=\"%s\"; \
00284              JAVA=\"%s\"; \
00285              EGREP=\"%s\"; \
00286              AWK=\"%s\"; \
00287              SED=\"%s\"; \
00288              TMP_TRAIN_DATA_ARFF=\"/tmp/.weka_part_train_data_arff-%u\"; \
00289              TMP_TRAIN_DATA_CSV=\"%s\"; \
00290              JVM_OPTS=\"-Xmx1024m -cp $WEKA_JAR\"; \
00291              CSV_CONVERTER=\"weka.core.converters.CSVLoader\"; \
00292              PART=\"weka.classifiers.rules.PART\"; \
00293              PART_OPTS=\"-v -t $TMP_TRAIN_DATA_ARFF -x 2 -c 1 -d %s\"; \
00294              $JAVA $JVM_OPTS $CSV_CONVERTER $TMP_TRAIN_DATA_CSV \
00295                    > $TMP_TRAIN_DATA_ARFF; \
00296              $EGREP '@attribute' $TMP_TRAIN_DATA_ARFF \
00297                    | $SED '2,$d' \
00298                    | $AWK '{for (i=3;i<=NF;i++) printf \"%%s \", $i; printf \"\\n\"}' \
00299                    > %s; \
00300              $JAVA $JVM_OPTS $PART $PART_OPTS >> /dev/null; \
00301              rm -f $TMP_TRAIN_DATA_ARFF; \
00302              rm -f $TMP_TRAIN_DATA_CSV", 
00303              weka_jar_fname, HAPLO_JAVA, HAPLO_EGREP, HAPLO_AWK, HAPLO_SED,
00304              pid, tmp_data_csv_fname, model_fname, labels_fname);
00305 
00306     system(script);
00307 
00308     return NULL;
00309 }
00310 
00311 
00322 Error* predict_labels_with_weka_j48_model
00323 (
00324     Vector_u32**      labels_out,
00325     Vector_d**        confs_out,
00326     const Matrix_i32* markers,
00327     const char*       labels_fname,
00328     const char*       model_fname,
00329     const char*       weka_jar_fname
00330 )
00331 {
00332     static char script[4096] = {0};
00333     static char tmp_data_fname[256] = {0};
00334     static char tmp_result_fname[256] = {0};
00335     char label_buf[256] = {0};
00336     char* label_buf_ptr;
00337     char  c;
00338 
00339     FILE*    fp;
00340     uint32_t pid;
00341     uint32_t s, num_samples;
00342     uint32_t m, num_markers;
00343     uint32_t label;
00344     double   conf;
00345     char*    labels_str = NULL;
00346     Error*   err;
00347     int      slen = 1024;
00348     char     str[slen];
00349 
00350     if ((err = read_weka_model_labels(&labels_str, labels_fname)))
00351     {
00352         return err;
00353     }
00354 
00355     pid = getpid();
00356     num_samples = markers->num_rows;
00357     num_markers = markers->num_cols;
00358 
00359     create_vector_u32(labels_out, num_samples);
00360     create_vector_d(confs_out, num_samples);
00361 
00362     snprintf(tmp_data_fname, 256, "/tmp/.weka_j48_predict_data-%u", pid);
00363     snprintf(tmp_result_fname, 256, "/tmp/.weka_j48_predict_result-%u", pid);
00364 
00365     if ((fp = fopen(tmp_data_fname, "w")) == NULL)
00366     {
00367         snprintf(str, slen, "%s: %s", tmp_data_fname, strerror(errno));
00368         return JWSC_EIO(str);
00369     }
00370 
00371     fprintf(fp, "@relation %s\n\n", tmp_data_fname);
00372 
00373     fprintf(fp, "@attribute label %s\n", labels_str);
00374 
00375     for (m = 0; m < num_markers; m++)
00376     {
00377         fprintf(fp, "@attribute %u numeric\n", m);
00378     }
00379 
00380     fprintf(fp, "\n@data\n");
00381     for (s = 0; s < num_samples; s++)
00382     {
00383         fprintf(fp, "?");
00384         for (m = 0; m < num_markers; m++)
00385         {
00386             if (markers->elts[ s ][ m ] > 0)
00387             {
00388                 fprintf(fp, ",%u", markers->elts[ s ][ m ]);
00389             }
00390             else
00391             {
00392                 fprintf(fp, ",?");
00393             }
00394         }
00395         fprintf(fp, "\n");
00396     }
00397 
00398     if (fclose(fp) != 0)
00399     {
00400         snprintf(str, slen, "%s: %s", tmp_data_fname, strerror(errno));
00401         return JWSC_EIO(str);
00402     }
00403 
00404     snprintf(script, 4096, 
00405             "WEKA_JAR=\"%s\"; \
00406              JAVA=\"%s\"; \
00407              AWK=\"%s\"; \
00408              SED=\"%s\"; \
00409              TMP_DATA=\"%s\"; \
00410              TMP_RESULT=\"%s\"; \
00411              JVM_OPTS=\"-Xmx1024m -cp $WEKA_JAR\"; \
00412              J48=\"weka.classifiers.trees.J48\"; \
00413              J48_OPTS=\"-c 1 -T $TMP_DATA -l %s -p 0\"; \
00414              $JAVA $JVM_OPTS $J48 $J48_OPTS \
00415                     | $SED '$d' \
00416                     | $SED 's,^[^ ]\\+ \\(.*\\) \\([^ ]\\+\\) [^ ]\\+[ ]*$,\\1\\n\\2,' \
00417                     > $TMP_RESULT",
00418              weka_jar_fname, HAPLO_JAVA, HAPLO_AWK, HAPLO_SED, tmp_data_fname,
00419              tmp_result_fname, model_fname);
00420     system(script);
00421 
00422     if (!(fp = fopen(tmp_result_fname, "r")))
00423     {
00424         snprintf(str, slen, "%s: %s", tmp_result_fname, strerror(errno));
00425         return JWSC_EIO(str);
00426     }
00427 
00428     for (s = 0; s < num_samples; s++)
00429     {
00430         label_buf_ptr = label_buf;
00431         while ((c = fgetc(fp)) != EOF && c != '\n' && 
00432                 label_buf_ptr != (label_buf+256))
00433         {
00434             *label_buf_ptr = c;
00435             label_buf_ptr++;
00436         }
00437         if (label_buf_ptr != (label_buf+256))
00438         {
00439             *label_buf_ptr = 0;
00440         }
00441 
00442         if (fscanf(fp, "%lf\n", &conf) != 1)
00443         {
00444             snprintf(str, slen, "%s: Sample %d: %s", tmp_result_fname, 
00445                     s+1, "Invalid output from Weka J48");
00446             return JWSC_EIO(str);
00447         }
00448         if ((err = lookup_haplo_group_index_from_label(&label, label_buf)))
00449         {
00450             snprintf(str, slen, "%s: Sample %d: %s", tmp_result_fname, s+1,
00451                     err->msg);
00452             return JWSC_EIO(str);
00453         }
00454 
00455         (*labels_out)->elts[ s ] = label;
00456         (*confs_out)->elts[ s ] = conf;
00457     }
00458 
00459     if (fclose(fp) != 0)
00460     {
00461         snprintf(str, slen, "%s: %s", tmp_result_fname, strerror(errno));
00462         return JWSC_EIO(str);
00463     }
00464 
00465     snprintf(script, 4096, "rm -f %s", tmp_data_fname);
00466     system(script);
00467     snprintf(script, 4096, "rm -f %s", tmp_result_fname);
00468     system(script);
00469 
00470     free(labels_str);
00471 
00472     return NULL;
00473 }
00474 
00475 
00486 Error* predict_labels_with_weka_part_model
00487 (
00488     Vector_u32**      labels_out,
00489     Vector_d**        confs_out,
00490     const Matrix_i32* markers,
00491     const char*       labels_fname,
00492     const char*       model_fname,
00493     const char*       weka_jar_fname
00494 )
00495 {
00496     static char script[4096] = {0};
00497     static char tmp_data_fname[256] = {0};
00498     static char tmp_result_fname[256] = {0};
00499     char label_buf[256] = {0};
00500     char* label_buf_ptr;
00501     char  c;
00502 
00503     FILE*    fp;
00504     uint32_t pid;
00505     uint32_t s, num_samples;
00506     uint32_t m, num_markers;
00507     uint32_t label;
00508     double   conf;
00509     char*    labels_str = NULL;
00510     Error*   err;
00511     int      slen = 1024;
00512     char     str[slen];
00513 
00514     if ((err = read_weka_model_labels(&labels_str, labels_fname)))
00515     {
00516         return err;
00517     }
00518 
00519     pid = getpid();
00520     num_samples = markers->num_rows;
00521     num_markers = markers->num_cols;
00522 
00523     create_vector_u32(labels_out, num_samples);
00524     create_vector_d(confs_out, num_samples);
00525 
00526     snprintf(tmp_data_fname, 256, "/tmp/.weka_part_predict_data-%u", pid);
00527     snprintf(tmp_result_fname, 256, "/tmp/.weka_part_predict_result-%u", pid);
00528 
00529     if ((fp = fopen(tmp_data_fname, "w")) == NULL)
00530     {
00531         snprintf(str, slen, "%s: %s", tmp_data_fname, strerror(errno));
00532         return JWSC_EIO(str);
00533     }
00534 
00535     fprintf(fp, "@relation %s\n\n", tmp_data_fname);
00536 
00537     fprintf(fp, "@attribute label %s\n", labels_str);
00538 
00539     for (m = 0; m < num_markers; m++)
00540     {
00541         fprintf(fp, "@attribute %u numeric\n", m);
00542     }
00543 
00544     fprintf(fp, "\n@data\n");
00545     for (s = 0; s < num_samples; s++)
00546     {
00547         fprintf(fp, "?");
00548         for (m = 0; m < num_markers; m++)
00549         {
00550             if (markers->elts[ s ][ m ] > 0)
00551             {
00552                 fprintf(fp, ",%u", markers->elts[ s ][ m ]);
00553             }
00554             else
00555             {
00556                 fprintf(fp, ",?");
00557             }
00558         }
00559         fprintf(fp, "\n");
00560     }
00561 
00562     if (fclose(fp) != 0)
00563     {
00564         snprintf(str, slen, "%s: %s", tmp_data_fname, strerror(errno));
00565         return JWSC_EIO(str);
00566     }
00567 
00568     snprintf(script, 4096, 
00569             "WEKA_JAR=\"%s\"; \
00570              JAVA=\"%s\"; \
00571              AWK=\"%s\"; \
00572              SED=\"%s\"; \
00573              TMP_DATA=\"%s\"; \
00574              TMP_RESULT=\"%s\"; \
00575              JVM_OPTS=\"-Xmx1024m -cp $WEKA_JAR\"; \
00576              PART=\"weka.classifiers.rules.PART\"; \
00577              PART_OPTS=\"-c 1 -T $TMP_DATA -l %s -p 0\"; \
00578              $JAVA $JVM_OPTS $PART $PART_OPTS \
00579                     | $SED '$d' \
00580                     | $SED 's,^[^ ]\\+ \\(.*\\) \\([^ ]\\+\\) [^ ]\\+[ ]*$,\\1\\n\\2,' \
00581                     > $TMP_RESULT",
00582              weka_jar_fname, HAPLO_JAVA, HAPLO_AWK, HAPLO_SED, tmp_data_fname,
00583              tmp_result_fname, model_fname);
00584     system(script);
00585 
00586     if (!(fp = fopen(tmp_result_fname, "r")))
00587     {
00588         snprintf(str, slen, "%s: %s", tmp_result_fname, strerror(errno));
00589         return JWSC_EIO(str);
00590     }
00591 
00592     for (s = 0; s < num_samples; s++)
00593     {
00594         label_buf_ptr = label_buf;
00595         while ((c = fgetc(fp)) != EOF && c != '\n' && 
00596                 label_buf_ptr != (label_buf+256))
00597         {
00598             *label_buf_ptr = c;
00599             label_buf_ptr++;
00600         }
00601         if (label_buf_ptr != (label_buf+256))
00602         {
00603             *label_buf_ptr = 0;
00604         }
00605 
00606         if (fscanf(fp, "%lf\n", &conf) != 1)
00607         {
00608             snprintf(str, slen, "%s: Sample %d: %s", tmp_result_fname, 
00609                     s+1, "Invalid output from Weka PART");
00610             return JWSC_EIO(str);
00611         }
00612         if ((err = lookup_haplo_group_index_from_label(&label, label_buf)))
00613         {
00614             snprintf(str, slen, "%s: Sample %d: %s", tmp_result_fname, 
00615                     s+1, err->msg);
00616             return JWSC_EIO(str);
00617         }
00618 
00619         (*labels_out)->elts[ s ] = label;
00620         (*confs_out)->elts[ s ] = conf;
00621     }
00622 
00623     if (fclose(fp) != 0)
00624     {
00625         snprintf(str, slen, "%s: %s", tmp_result_fname, strerror(errno));
00626         return JWSC_EIO(str);
00627     }
00628 
00629     snprintf(script, 4096, "rm -f %s", tmp_data_fname);
00630     system(script);
00631     snprintf(script, 4096, "rm -f %s", tmp_result_fname);
00632     system(script);
00633 
00634     free(labels_str);
00635 
00636     return NULL;
00637 }
00638 
00639 
00641 static Error* read_weka_xml_doc
00642 (
00643     xmlDoc**    xml_doc_out,
00644     const char* xml_fname,
00645     const char* dtd_fname
00646 )
00647 {
00648     xmlParserCtxt* xml_parse_ctxt;
00649     xmlValidCtxt*  xml_valid_ctxt;
00650     xmlDtd*        xml_dtd;
00651     int            slen = 256;
00652     char           str[slen];
00653 
00654     assert(xml_parse_ctxt = xmlNewParserCtxt());
00655 
00656     if (!(*xml_doc_out = xmlCtxtReadFile(xml_parse_ctxt, xml_fname, NULL, 0)))
00657     {
00658         snprintf(str, slen, "%s: %s", xml_fname, "Could not parse file");
00659         return JWSC_EARG(str);
00660     } 
00661 
00662     xmlFreeParserCtxt(xml_parse_ctxt);
00663 
00664     if (dtd_fname)
00665     {
00666         assert(xml_valid_ctxt = xmlNewValidCtxt());
00667 
00668         if (!(xml_dtd = xmlParseDTD(NULL, (xmlChar*)dtd_fname)))
00669         {
00670             snprintf(str, slen, "%s: %s", dtd_fname, "Could not parse DTD");
00671             return JWSC_EARG(str);
00672         }
00673 
00674         if (!xmlValidateDtd(xml_valid_ctxt, *xml_doc_out, xml_dtd))
00675         {
00676             snprintf(str, slen, "%s: %s", xml_fname, "XML file not valid");
00677             return JWSC_EARG(str);
00678         }
00679 
00680         xmlFreeValidCtxt(xml_valid_ctxt);
00681         xmlFreeDtd(xml_dtd);
00682     }
00683 
00684     return NULL;
00685 }
00686 
00687 
00689 static Error* create_weka_model_tree_from_xml_node
00690 (
00691      Weka_model_tree**      tree_out, 
00692      const Weka_model_tree* parent,
00693      uint32_t               parent_label,
00694      xmlNode*               xml_node,
00695      const char*            model_dirname
00696 )
00697 {
00698     Weka_model_tree* tree;
00699     uint32_t i, j;
00700     uint32_t len;
00701     uint32_t label;
00702     uint32_t altlabel;
00703     uint32_t num_altlabels;
00704     xmlNode* it;
00705     xmlNode* itt;
00706     Error*   err;
00707     const char* fname;
00708     char*       fname_buf;
00709 
00710     if (*tree_out)
00711     {
00712         free_weka_model_tree(*tree_out);
00713     }
00714 
00715     assert(*tree_out = malloc(sizeof(Weka_model_tree)));
00716     tree = *tree_out;
00717 
00718     tree->parent       = parent;
00719     tree->parent_label = parent_label;
00720     tree->num_groups   = 0;
00721     tree->subtrees     = NULL;
00722     tree->labels       = NULL;
00723     tree->altlabels    = NULL;
00724     tree->labels_fname = NULL;
00725     tree->model_fname  = NULL;
00726 
00727     assert(XMLStrEqual(xml_node->name, "model"));
00728 
00729     for (it = xml_node->children; it; it = it->next)
00730     {
00731         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "group"))
00732         {
00733             tree->num_groups++;
00734         }
00735     }
00736 
00737     assert(tree->subtrees = calloc(tree->num_groups, sizeof(void*)));
00738     assert(tree->altlabels = calloc(tree->num_groups, sizeof(void*)));
00739     create_vector_u32(&(tree->labels), tree->num_groups);
00740 
00741     i = 0;
00742     for (it = xml_node->children; it; it = it->next)
00743     {
00744         j = 0;
00745         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model-file"))
00746         {
00747             fname = (const char*) it->children->content;
00748             len = strlen(model_dirname) + strlen(fname) + 2;
00749             fname_buf = malloc(len * sizeof(char));
00750             snprintf(fname_buf, len, "%s/%s", model_dirname, fname);
00751             tree->model_fname = fname_buf;
00752         }
00753         else if (it->type == XML_ELEMENT_NODE && 
00754                 XMLStrEqual(it->name, "labels-file"))
00755         {
00756             fname = (const char*) it->children->content;
00757             len = strlen(model_dirname) + strlen(fname) + 2;
00758             fname_buf = malloc(len * sizeof(char));
00759             snprintf(fname_buf, len, "%s/%s", model_dirname, fname);
00760             tree->labels_fname = fname_buf;
00761         }
00762         else if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "group"))
00763         {
00764             num_altlabels = 0;
00765             for (itt = it->children; itt; itt = itt->next)
00766             {
00767                 if (itt->type == XML_ELEMENT_NODE &&
00768                         XMLStrEqual(itt->name, "altlabel"))
00769                 {
00770                     num_altlabels++;
00771                 }
00772             }
00773             if (num_altlabels)
00774             {
00775                 create_vector_u32(&(tree->altlabels[ i ]), num_altlabels);
00776             }
00777 
00778             for (itt = it->children; itt; itt = itt->next)
00779             {
00780                 if (itt->type == XML_ELEMENT_NODE &&
00781                         XMLStrEqual(itt->name, "label"))
00782                 {
00783                     if ((err = lookup_haplo_group_index_from_label(&label, 
00784                                     (const char*) itt->children->content)))
00785                     {
00786                         return err;
00787                     }
00788                     tree->labels->elts[ i ] = label;
00789                 }
00790                 else if (itt->type == XML_ELEMENT_NODE &&
00791                         XMLStrEqual(itt->name, "altlabel"))
00792                 {
00793                     if ((err = lookup_haplo_group_index_from_label(&altlabel, 
00794                                     (const char*) itt->children->content)))
00795                     {
00796                         return err;
00797                     }
00798                     tree->altlabels[ i ]->elts[ j++ ] = altlabel;
00799                 }
00800                 else if (itt->type == XML_ELEMENT_NODE &&
00801                         XMLStrEqual(itt->name, "model"))
00802                 {
00803                     if ((err = create_weka_model_tree_from_xml_node(
00804                                  &(tree->subtrees[ i ]), tree, label, itt,
00805                                  model_dirname)))
00806                     {
00807                         return err;
00808                     }
00809                 }
00810             }
00811             i++;
00812         }
00813     }
00814     assert(i == tree->num_groups);
00815 
00816     return NULL;
00817 }
00818 
00819 
00824 static Error* create_model_training_data
00825 (
00826     Vector_u32**      train_labels_out,
00827     Matrix_i32**      train_markers_out,
00828     const Vector_u32* data_labels,
00829     const Matrix_i32* data_markers,
00830     const Vector_u32* model_labels,
00831     Vector_u32*const* model_altlabels
00832 )
00833 {
00834     uint8_t     b;
00835     uint32_t    i, j, k;
00836     uint32_t    n;
00837     Vector_u32* train_labels = NULL;
00838     Matrix_i32* train_markers = NULL;
00839 
00840     copy_vector_u32(&train_labels, data_labels);
00841     copy_matrix_i32(&train_markers, data_markers);
00842 
00843     n = 0;
00844     for (i = 0; i < data_labels->num_elts; i++)
00845     {
00846         for (j = 0; j < model_labels->num_elts; j++)
00847         {
00848             if (is_ancestor(data_labels->elts[ i ], model_labels->elts[ j ]))
00849             {
00850                 train_labels->elts[ n ] = model_labels->elts[ j ];
00851                 copy_matrix_block_into_matrix_i32(train_markers, n, 0,
00852                         data_markers, i, 0, 1, data_markers->num_cols);
00853                 n++;
00854                 break;
00855             }
00856             else if (model_altlabels[ j ])
00857             {
00858                 for (k = 0; k < model_altlabels[ j ]->num_elts; k++)
00859                 {
00860                     if ((b = is_ancestor(data_labels->elts[ i ], 
00861                                 model_altlabels[ j ]->elts[ k ])))
00862                     {
00863                         train_labels->elts[ n ] = model_labels->elts[ j ];
00864                         copy_matrix_block_into_matrix_i32(train_markers, n, 0,
00865                                 data_markers, i, 0, 1, data_markers->num_cols);
00866                         n++;
00867                         break;
00868                     }
00869                 }
00870                 if (b)
00871                 {
00872                     break;
00873                 }
00874             }
00875         }
00876     }
00877 
00878     if (!n)
00879     {
00880         return JWSC_EARG("No data for model");
00881     }
00882 
00883     copy_vector_section_u32(train_labels_out, train_labels, 0, n);
00884     copy_matrix_block_i32(train_markers_out, train_markers, 0, 0, n, 
00885             train_markers->num_cols);
00886 
00887     free_vector_u32(train_labels);
00888     free_matrix_i32(train_markers);
00889 
00890     return NULL;
00891 }
00892 
00893 
00895 static Error* train_weka_j48_model_node
00896 (
00897     Weka_model_node*  node,
00898     const Vector_u32* labels, 
00899     const Matrix_i32* markers,
00900     const char*       weka_jar_fname
00901 )
00902 {
00903     uint32_t i;
00904     Vector_u32* node_labels  = NULL;
00905     Matrix_i32* node_markers = NULL;
00906     Error*  err;
00907     int     slen = 256;
00908     char    str[slen];
00909 
00910     if ((err = create_model_training_data(&node_labels, &node_markers, labels,
00911                     markers, node->labels, node->altlabels)))
00912     {
00913         snprintf(str, slen, "%s: %s", node->model_fname, err->msg);
00914         return JWSC_EARG(str);
00915     }
00916 
00917     if ((err = train_weka_j48_model(node_labels, node_markers,
00918                     node->labels_fname, node->model_fname, weka_jar_fname)))
00919     {
00920         return err;
00921     }
00922 
00923     for (i = 0; i < node->num_groups; i++)
00924     {
00925         if (node->subtrees[ i ])
00926         {
00927             if ((err = train_weka_j48_model_node(node->subtrees[ i ], labels,
00928                         markers, weka_jar_fname)))
00929             {
00930                 return err;
00931             }
00932         }
00933     }
00934 
00935     free_vector_u32(node_labels);
00936     free_matrix_i32(node_markers);
00937 
00938     return NULL;
00939 }
00940 
00941 
00943 static Error* train_weka_part_model_node
00944 (
00945     Weka_model_node*  node,
00946     const Vector_u32* labels, 
00947     const Matrix_i32* markers,
00948     const char*       weka_jar_fname
00949 )
00950 {
00951     uint32_t i;
00952     Vector_u32* node_labels  = NULL;
00953     Matrix_i32* node_markers = NULL;
00954     Error*  err;
00955     int     slen = 256;
00956     char    str[slen];
00957 
00958     if ((err = create_model_training_data(&node_labels, &node_markers, labels,
00959                     markers, node->labels, node->altlabels)))
00960     {
00961         snprintf(str, slen, "%s: %s", node->model_fname, err->msg);
00962         return JWSC_EARG(str);
00963     }
00964 
00965     if ((err = train_weka_part_model(node_labels, node_markers,
00966                     node->labels_fname, node->model_fname, weka_jar_fname)))
00967     {
00968         return err;
00969     }
00970 
00971     for (i = 0; i < node->num_groups; i++)
00972     {
00973         if (node->subtrees[ i ])
00974         {
00975             if ((err = train_weka_part_model_node(node->subtrees[ i ], labels,
00976                         markers, weka_jar_fname)))
00977             {
00978                 return err;
00979             }
00980         }
00981     }
00982 
00983     free_vector_u32(node_labels);
00984     free_matrix_i32(node_markers);
00985 
00986     return NULL;
00987 }
00988 
00989 
00999 Error* train_weka_j48_model_tree
01000 (
01001     Weka_model_tree** tree_out,
01002     const Vector_u32* labels, 
01003     const Matrix_i32* markers,
01004     const char*       tree_xml_fname,
01005     const char*       tree_dtd_fname,
01006     const char*       model_dirname,
01007     const char*       weka_jar_fname
01008 )
01009 {
01010     xmlDoc*  xml_doc;
01011     xmlNode* xml_node;
01012     xmlNode* it;
01013     Error*   err;
01014     int      slen = 256;
01015     char     str[slen];
01016 
01017     assert(tree_xml_fname);
01018 
01019     if ((err = read_weka_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname)))
01020     {
01021         return err;
01022     }
01023 
01024     for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next)
01025     {
01026         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model"))
01027         {
01028             xml_node = it;
01029             break;
01030         }
01031     }
01032 
01033     if ((err = create_weka_model_tree_from_xml_node(tree_out, NULL, 0, 
01034                     xml_node, model_dirname)))
01035     {
01036         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
01037         return JWSC_EARG(str);
01038     }
01039 
01040     if ((err = train_weka_j48_model_node(*tree_out, labels, markers,
01041                     weka_jar_fname)))
01042     {
01043         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
01044         return JWSC_EARG(str);
01045     }
01046 
01047     return NULL;
01048 }
01049 
01050 
01060 Error* train_weka_part_model_tree
01061 (
01062     Weka_model_tree** tree_out,
01063     const Vector_u32* labels, 
01064     const Matrix_i32* markers,
01065     const char*       tree_xml_fname,
01066     const char*       tree_dtd_fname,
01067     const char*       model_dirname,
01068     const char*       weka_jar_fname
01069 )
01070 {
01071     xmlDoc*  xml_doc;
01072     xmlNode* xml_node;
01073     xmlNode* it;
01074     Error*   err;
01075     int      slen = 256;
01076     char     str[slen];
01077 
01078     assert(tree_xml_fname);
01079 
01080     if ((err = read_weka_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname)))
01081     {
01082         return err;
01083     }
01084 
01085     for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next)
01086     {
01087         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model"))
01088         {
01089             xml_node = it;
01090             break;
01091         }
01092     }
01093 
01094     if ((err = create_weka_model_tree_from_xml_node(tree_out, NULL, 0, 
01095                     xml_node, model_dirname)))
01096     {
01097         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
01098         return JWSC_EARG(str);
01099     }
01100 
01101     if ((err = train_weka_part_model_node(*tree_out, labels, markers,
01102                     weka_jar_fname)))
01103     {
01104         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
01105         return JWSC_EARG(str);
01106     }
01107 
01108     return NULL;
01109 }
01110 
01111 
01113 static Error* recursively_predict_j48_labels_in_model_tree
01114 (
01115     Vector_u32**           labels_out,
01116     Vector_d**             confs_in_out,
01117     const Weka_model_tree* tree,
01118     const Matrix_i32*      markers,
01119     const char*            weka_jar_fname
01120 )
01121 {
01122     uint32_t    s, num_samples;
01123     uint32_t    num_markers;
01124     uint32_t    n, N;
01125     uint32_t    i;
01126     uint32_t    subtree_label;
01127     Vector_d*   confs          = NULL;
01128     Vector_u32* sample_indices = NULL;
01129     Matrix_i32* markers_subset = NULL;
01130     Vector_u32* labels_subset  = NULL;
01131     Vector_d*   confs_subset   = NULL;
01132     Error*      err;
01133 
01134     num_samples = markers->num_rows;
01135     num_markers = markers->num_cols;
01136 
01137     assert((*confs_in_out)->num_elts == num_samples);
01138 
01139     if ((err = predict_labels_with_weka_j48_model(labels_out, &confs,
01140                     markers, tree->labels_fname, tree->model_fname,
01141                     weka_jar_fname)) != NULL)
01142     {
01143         return err;
01144     }
01145     for (s = 0; s < num_samples; s++)
01146     {
01147         (*confs_in_out)->elts[ s ] *= confs->elts[ s ];
01148     }
01149 
01150     create_vector_u32(&sample_indices, num_samples);
01151 
01152     for (i = 0; i < tree->num_groups; i++)
01153     {
01154         if (!tree->subtrees[ i ])
01155             continue;
01156 
01157         subtree_label = tree->subtrees[ i ]->parent_label;
01158 
01159         N = 0;
01160         for (s = 0; s < num_samples; s++)
01161         {
01162             if ((*labels_out)->elts[ s ] == subtree_label)
01163             {
01164                 sample_indices->elts[ N++ ] = s;
01165             }
01166         }
01167 
01168         if (!N)
01169             continue;
01170 
01171         create_matrix_i32(&markers_subset, N, num_markers);
01172         create_vector_d(&confs_subset, N);
01173 
01174         for (n = 0; n < N; n++)
01175         {
01176             s = sample_indices->elts[ n ];
01177             copy_matrix_block_into_matrix_i32(markers_subset, n, 0,
01178                     markers, s, 0, 1, num_markers);
01179             confs_subset->elts[ n ] = (*confs_in_out)->elts[ s ];
01180         }
01181 
01182         if ((err = recursively_predict_j48_labels_in_model_tree(
01183                         &labels_subset, &confs_subset, tree->subtrees[ i ], 
01184                         markers_subset, weka_jar_fname)) != NULL)
01185         {
01186             return err;
01187         }
01188 
01189         for (n = 0; n < N; n++)
01190         {
01191             s = sample_indices->elts[ n ];
01192             (*labels_out)->elts[ s ] = labels_subset->elts[ n ];
01193             (*confs_in_out)->elts[ s ] = confs_subset->elts[ n ];
01194         }
01195     }
01196 
01197     free_vector_u32(sample_indices);
01198     free_matrix_i32(markers_subset);
01199     free_vector_u32(labels_subset);
01200     free_vector_d(confs_subset);
01201 
01202     return NULL;
01203 }
01204 
01205 
01207 static Error* recursively_predict_part_labels_in_model_tree
01208 (
01209     Vector_u32**           labels_out,
01210     Vector_d**             confs_in_out,
01211     const Weka_model_tree* tree,
01212     const Matrix_i32*      markers,
01213     const char*            weka_jar_fname
01214 )
01215 {
01216     uint32_t    s, num_samples;
01217     uint32_t    num_markers;
01218     uint32_t    n, N;
01219     uint32_t    i;
01220     uint32_t    subtree_label;
01221     Vector_d*   confs          = NULL;
01222     Vector_u32* sample_indices = NULL;
01223     Matrix_i32* markers_subset = NULL;
01224     Vector_u32* labels_subset  = NULL;
01225     Vector_d*   confs_subset   = NULL;
01226     Error*      err;
01227 
01228     num_samples = markers->num_rows;
01229     num_markers = markers->num_cols;
01230 
01231     assert((*confs_in_out)->num_elts == num_samples);
01232 
01233     if ((err = predict_labels_with_weka_part_model(labels_out, &confs,
01234                     markers, tree->labels_fname, tree->model_fname,
01235                     weka_jar_fname)) != NULL)
01236     {
01237         return err;
01238     }
01239     for (s = 0; s < num_samples; s++)
01240     {
01241         (*confs_in_out)->elts[ s ] *= confs->elts[ s ];
01242     }
01243 
01244     create_vector_u32(&sample_indices, num_samples);
01245 
01246     for (i = 0; i < tree->num_groups; i++)
01247     {
01248         if (!tree->subtrees[ i ])
01249             continue;
01250 
01251         subtree_label = tree->subtrees[ i ]->parent_label;
01252 
01253         N = 0;
01254         for (s = 0; s < num_samples; s++)
01255         {
01256             if ((*labels_out)->elts[ s ] == subtree_label)
01257             {
01258                 sample_indices->elts[ N++ ] = s;
01259             }
01260         }
01261 
01262         if (!N)
01263             continue;
01264 
01265         create_matrix_i32(&markers_subset, N, num_markers);
01266         create_vector_d(&confs_subset, N);
01267 
01268         for (n = 0; n < N; n++)
01269         {
01270             s = sample_indices->elts[ n ];
01271             copy_matrix_block_into_matrix_i32(markers_subset, n, 0,
01272                     markers, s, 0, 1, num_markers);
01273             confs_subset->elts[ n ] = (*confs_in_out)->elts[ s ];
01274         }
01275 
01276         if ((err = recursively_predict_part_labels_in_model_tree(
01277                         &labels_subset, &confs_subset, tree->subtrees[ i ], 
01278                         markers_subset, weka_jar_fname)) != NULL)
01279         {
01280             return err;
01281         }
01282 
01283         for (n = 0; n < N; n++)
01284         {
01285             s = sample_indices->elts[ n ];
01286             (*labels_out)->elts[ s ] = labels_subset->elts[ n ];
01287             (*confs_in_out)->elts[ s ] = confs_subset->elts[ n ];
01288         }
01289     }
01290 
01291     free_vector_u32(sample_indices);
01292     free_matrix_i32(markers_subset);
01293     free_vector_u32(labels_subset);
01294     free_vector_d(confs_subset);
01295 
01296     return NULL;
01297 }
01298 
01299 
01314 Error* predict_labels_with_weka_j48_model_tree
01315 (
01316     Vector_u32**           labels_out,
01317     Vector_d**             confs_out,
01318     const Matrix_i32*      markers,
01319     const Weka_model_tree* tree,
01320     const char*            weka_jar_fname
01321 )
01322 {
01323     Error* err;
01324 
01325     create_init_vector_d(confs_out, markers->num_rows, 1.0);
01326 
01327     if ((err = recursively_predict_j48_labels_in_model_tree(labels_out,
01328                     confs_out, tree, markers, weka_jar_fname)) != NULL)
01329     {
01330         return err;
01331     }
01332 
01333     return NULL;
01334 }
01335 
01336 
01351 Error* predict_labels_with_weka_part_model_tree
01352 (
01353     Vector_u32**           labels_out,
01354     Vector_d**             confs_out,
01355     const Matrix_i32*      markers,
01356     const Weka_model_tree* tree,
01357     const char*            weka_jar_fname
01358 )
01359 {
01360     Error* err;
01361 
01362     create_init_vector_d(confs_out, markers->num_rows, 1.0);
01363 
01364     if ((err = recursively_predict_part_labels_in_model_tree(labels_out,
01365                     confs_out, tree, markers, weka_jar_fname)) != NULL)
01366     {
01367         return err;
01368     }
01369 
01370     return NULL;
01371 }
01372 
01373 
01380 Error* read_weka_model_tree
01381 (
01382     Weka_model_tree** tree_out,
01383     const char*       tree_xml_fname,
01384     const char*       tree_dtd_fname,
01385     const char*       model_dirname
01386 )
01387 {
01388     xmlDoc*  xml_doc;
01389     xmlNode* xml_node;
01390     xmlNode* it;
01391     Error*   err;
01392     int      slen = 256;
01393     char     str[slen];
01394 
01395     assert(tree_xml_fname && model_dirname);
01396 
01397     if ((err = read_weka_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname)))
01398     {
01399         return err;
01400     }
01401 
01402     for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next)
01403     {
01404         if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model"))
01405         {
01406             xml_node = it;
01407             break;
01408         }
01409     }
01410 
01411     if ((err = create_weka_model_tree_from_xml_node(tree_out, NULL, 0,
01412                     xml_node, model_dirname)))
01413     {
01414         snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg);
01415         return JWSC_EARG(str);
01416     }
01417 
01418     return NULL;
01419 }
01420 
01421 
01423 void free_weka_model_tree(Weka_model_tree* tree)
01424 {
01425     uint32_t i;
01426 
01427     if (!tree)
01428         return;
01429 
01430     for (i = 0; i < tree->num_groups; i++)
01431     {
01432         free_weka_model_tree(tree->subtrees[ i ]);
01433         free_vector_u32(tree->altlabels[ i ]);
01434     }
01435 
01436     free(tree->subtrees);
01437     free(tree->altlabels);
01438     free_vector_u32(tree->labels);
01439     free(tree->labels_fname);
01440     free(tree->model_fname);
01441     free(tree);
01442 }