/* Haplo Prediction: predict haplogroups */
00001 /* 00002 * This work is licensed under a Creative Commons 00003 * Attribution-Noncommercial-Share Alike 3.0 United States License. 00004 * 00005 * http://creativecommons.org/licenses/by-nc-sa/3.0/us/ 00006 * 00007 * You are free: 00008 * 00009 * to Share - to copy, distribute, display, and perform the work 00010 * to Remix - to make derivative works 00011 * 00012 * Under the following conditions: 00013 * 00014 * Attribution. You must attribute the work in the manner specified by the 00015 * author or licensor (but not in any way that suggests that they endorse you 00016 * or your use of the work). 00017 * 00018 * Noncommercial. You may not use this work for commercial purposes. 00019 * 00020 * Share Alike. If you alter, transform, or build upon this work, you may 00021 * distribute the resulting work only under the same or similar license to 00022 * this one. 00023 * 00024 * For any reuse or distribution, you must make clear to others the license 00025 * terms of this work. The best way to do this is by including this header. 00026 * 00027 * Any of the above conditions can be waived if you get permission from the 00028 * copyright holder. 00029 * 00030 * Apart from the remix rights granted under this license, nothing in this 00031 * license impairs or restricts the author's moral rights. 
00032 */ 00033 00034 00046 #include <config.h> 00047 00048 #include <stdlib.h> 00049 #include <stdio.h> 00050 #include <assert.h> 00051 #include <inttypes.h> 00052 #include <unistd.h> 00053 #include <string.h> 00054 #include <errno.h> 00055 00056 #include <libxml/tree.h> 00057 #include <libxml/parser.h> 00058 #include <libxml/valid.h> 00059 00060 #ifdef HAPLO_HAVE_DMALLOC 00061 #include <dmalloc.h> 00062 #endif 00063 00064 #include <jwsc/base/error.h> 00065 #include <jwsc/base/file_io.h> 00066 #include <jwsc/vector/vector.h> 00067 #include <jwsc/matrix/matrix.h> 00068 00069 #include "xml.h" 00070 #include "haplo_groups.h" 00071 #include "weka.h" 00072 00073 00075 static Error* read_weka_model_labels 00076 ( 00077 char** labels_str_out, 00078 const char* labels_fname 00079 ) 00080 { 00081 static char buf[4096] = {0}; 00082 char* buf_ptr; 00083 char c; 00084 FILE* fp; 00085 int slen = 1024; 00086 char str[slen]; 00087 00088 if ((fp = fopen(labels_fname, "r")) == NULL) 00089 { 00090 snprintf(str, slen, "%s: %s", labels_fname, strerror(errno)); 00091 return JWSC_EIO(str); 00092 } 00093 00094 if (!skip_fp_spaces_and_comments(fp)) 00095 { 00096 snprintf(str, slen, "%s: %s", labels_fname, 00097 "Labels file not formatted properly"); 00098 return JWSC_EARG(str); 00099 } 00100 00101 buf_ptr = buf; 00102 while ((c = fgetc(fp)) != EOF && c != '\n' && buf_ptr != (buf+4096)) 00103 { 00104 *buf_ptr = c; 00105 buf_ptr++; 00106 } 00107 if (buf_ptr != (buf+4096)) 00108 { 00109 *buf_ptr = 0; 00110 } 00111 00112 if (*labels_str_out) 00113 { 00114 free(*labels_str_out); 00115 } 00116 *labels_str_out = malloc((strlen(buf)+1)*sizeof(char)); 00117 strcpy(*labels_str_out, buf); 00118 00119 if (fclose(fp) != 0) 00120 { 00121 snprintf(str, slen, "%s: %s", labels_fname, strerror(errno)); 00122 return JWSC_EIO(str); 00123 } 00124 00125 return NULL; 00126 } 00127 00128 00140 Error* train_weka_j48_model 00141 ( 00142 const Vector_u32* labels, 00143 const Matrix_i32* markers, 00144 const char* 
labels_fname, 00145 const char* model_fname, 00146 const char* weka_jar_fname 00147 ) 00148 { 00149 static char script[4096] = {0}; 00150 static char tmp_data_csv_fname[256] = {0}; 00151 const char* label_str; 00152 00153 FILE* fp; 00154 uint32_t pid; 00155 uint32_t s, num_samples; 00156 uint32_t m, num_markers; 00157 int slen = 1024; 00158 char str[slen]; 00159 00160 pid = getpid(); 00161 00162 snprintf(tmp_data_csv_fname, 255, "/tmp/.weka_j48_train_data_csv-%u", pid); 00163 if (!(fp = fopen(tmp_data_csv_fname, "w"))) 00164 { 00165 snprintf(str, slen, "%s: %s", tmp_data_csv_fname, strerror(errno)); 00166 return JWSC_EIO(str); 00167 } 00168 00169 num_samples = markers->num_rows; 00170 num_markers = markers->num_cols; 00171 00172 for (s = 0; s < num_samples; s++) 00173 { 00174 assert(lookup_haplo_group_label_from_index(&label_str, 00175 labels->elts[ s ]) == NULL); 00176 fprintf(fp, "%s", label_str); 00177 for (m = 0; m < num_markers; m++) 00178 { 00179 fprintf(fp, ",%u", markers->elts[ s ][ m ]); 00180 } 00181 fprintf(fp, "\n"); 00182 } 00183 00184 if (fclose(fp) != 0) 00185 { 00186 snprintf(str, slen, "%s: %s", tmp_data_csv_fname, strerror(errno)); 00187 return JWSC_EIO(str); 00188 } 00189 00190 snprintf(script, 4096, 00191 "WEKA_JAR=\"%s\"; \ 00192 JAVA=\"%s\"; \ 00193 EGREP=\"%s\"; \ 00194 AWK=\"%s\"; \ 00195 SED=\"%s\"; \ 00196 TMP_TRAIN_DATA_ARFF=\"/tmp/.weka_j48_train_data_arff-%u\"; \ 00197 TMP_TRAIN_DATA_CSV=\"%s\"; \ 00198 JVM_OPTS=\"-Xmx1024m -cp $WEKA_JAR\"; \ 00199 CSV_CONVERTER=\"weka.core.converters.CSVLoader\"; \ 00200 J48=\"weka.classifiers.trees.J48\"; \ 00201 J48_OPTS=\"-v -t $TMP_TRAIN_DATA_ARFF -x 2 -c 1 -d %s\"; \ 00202 $JAVA $JVM_OPTS $CSV_CONVERTER $TMP_TRAIN_DATA_CSV \ 00203 > $TMP_TRAIN_DATA_ARFF; \ 00204 $EGREP '@attribute' $TMP_TRAIN_DATA_ARFF \ 00205 | $SED '2,$d' \ 00206 | $AWK '{for (i=3;i<=NF;i++) printf \"%%s \", $i; printf \"\\n\"}' \ 00207 > %s; \ 00208 $JAVA $JVM_OPTS $J48 $J48_OPTS >> /dev/null; \ 00209 rm -f 
$TMP_TRAIN_DATA_ARFF; \ 00210 rm -f $TMP_TRAIN_DATA_CSV", 00211 weka_jar_fname, HAPLO_JAVA, HAPLO_EGREP, HAPLO_AWK, HAPLO_SED, 00212 pid, tmp_data_csv_fname, model_fname, labels_fname); 00213 00214 system(script); 00215 00216 return NULL; 00217 } 00218 00219 00231 Error* train_weka_part_model 00232 ( 00233 const Vector_u32* labels, 00234 const Matrix_i32* markers, 00235 const char* labels_fname, 00236 const char* model_fname, 00237 const char* weka_jar_fname 00238 ) 00239 { 00240 static char script[4096] = {0}; 00241 static char tmp_data_csv_fname[256] = {0}; 00242 const char* label_str; 00243 00244 FILE* fp; 00245 uint32_t pid; 00246 uint32_t s, num_samples; 00247 uint32_t m, num_markers; 00248 int slen = 1024; 00249 char str[slen]; 00250 00251 pid = getpid(); 00252 00253 snprintf(tmp_data_csv_fname, 255, "/tmp/.weka_part_train_data_csv-%u", pid); 00254 if (!(fp = fopen(tmp_data_csv_fname, "w"))) 00255 { 00256 snprintf(str, slen, "%s: %s", tmp_data_csv_fname, 00257 strerror(errno)); 00258 return JWSC_EIO(str); 00259 } 00260 00261 num_samples = markers->num_rows; 00262 num_markers = markers->num_cols; 00263 00264 for (s = 0; s < num_samples; s++) 00265 { 00266 assert(lookup_haplo_group_label_from_index(&label_str, 00267 labels->elts[ s ]) == NULL); 00268 fprintf(fp, "%s", label_str); 00269 for (m = 0; m < num_markers; m++) 00270 { 00271 fprintf(fp, ",%u", markers->elts[ s ][ m ]); 00272 } 00273 fprintf(fp, "\n"); 00274 } 00275 00276 if (fclose(fp) != 0) 00277 { 00278 snprintf(str, slen, "%s: %s", tmp_data_csv_fname, strerror(errno)); 00279 return JWSC_EIO(str); 00280 } 00281 00282 snprintf(script, 4096, 00283 "WEKA_JAR=\"%s\"; \ 00284 JAVA=\"%s\"; \ 00285 EGREP=\"%s\"; \ 00286 AWK=\"%s\"; \ 00287 SED=\"%s\"; \ 00288 TMP_TRAIN_DATA_ARFF=\"/tmp/.weka_part_train_data_arff-%u\"; \ 00289 TMP_TRAIN_DATA_CSV=\"%s\"; \ 00290 JVM_OPTS=\"-Xmx1024m -cp $WEKA_JAR\"; \ 00291 CSV_CONVERTER=\"weka.core.converters.CSVLoader\"; \ 00292 PART=\"weka.classifiers.rules.PART\"; \ 00293 
PART_OPTS=\"-v -t $TMP_TRAIN_DATA_ARFF -x 2 -c 1 -d %s\"; \ 00294 $JAVA $JVM_OPTS $CSV_CONVERTER $TMP_TRAIN_DATA_CSV \ 00295 > $TMP_TRAIN_DATA_ARFF; \ 00296 $EGREP '@attribute' $TMP_TRAIN_DATA_ARFF \ 00297 | $SED '2,$d' \ 00298 | $AWK '{for (i=3;i<=NF;i++) printf \"%%s \", $i; printf \"\\n\"}' \ 00299 > %s; \ 00300 $JAVA $JVM_OPTS $PART $PART_OPTS >> /dev/null; \ 00301 rm -f $TMP_TRAIN_DATA_ARFF; \ 00302 rm -f $TMP_TRAIN_DATA_CSV", 00303 weka_jar_fname, HAPLO_JAVA, HAPLO_EGREP, HAPLO_AWK, HAPLO_SED, 00304 pid, tmp_data_csv_fname, model_fname, labels_fname); 00305 00306 system(script); 00307 00308 return NULL; 00309 } 00310 00311 00322 Error* predict_labels_with_weka_j48_model 00323 ( 00324 Vector_u32** labels_out, 00325 Vector_d** confs_out, 00326 const Matrix_i32* markers, 00327 const char* labels_fname, 00328 const char* model_fname, 00329 const char* weka_jar_fname 00330 ) 00331 { 00332 static char script[4096] = {0}; 00333 static char tmp_data_fname[256] = {0}; 00334 static char tmp_result_fname[256] = {0}; 00335 char label_buf[256] = {0}; 00336 char* label_buf_ptr; 00337 char c; 00338 00339 FILE* fp; 00340 uint32_t pid; 00341 uint32_t s, num_samples; 00342 uint32_t m, num_markers; 00343 uint32_t label; 00344 double conf; 00345 char* labels_str = NULL; 00346 Error* err; 00347 int slen = 1024; 00348 char str[slen]; 00349 00350 if ((err = read_weka_model_labels(&labels_str, labels_fname))) 00351 { 00352 return err; 00353 } 00354 00355 pid = getpid(); 00356 num_samples = markers->num_rows; 00357 num_markers = markers->num_cols; 00358 00359 create_vector_u32(labels_out, num_samples); 00360 create_vector_d(confs_out, num_samples); 00361 00362 snprintf(tmp_data_fname, 256, "/tmp/.weka_j48_predict_data-%u", pid); 00363 snprintf(tmp_result_fname, 256, "/tmp/.weka_j48_predict_result-%u", pid); 00364 00365 if ((fp = fopen(tmp_data_fname, "w")) == NULL) 00366 { 00367 snprintf(str, slen, "%s: %s", tmp_data_fname, strerror(errno)); 00368 return JWSC_EIO(str); 00369 } 00370 
00371 fprintf(fp, "@relation %s\n\n", tmp_data_fname); 00372 00373 fprintf(fp, "@attribute label %s\n", labels_str); 00374 00375 for (m = 0; m < num_markers; m++) 00376 { 00377 fprintf(fp, "@attribute %u numeric\n", m); 00378 } 00379 00380 fprintf(fp, "\n@data\n"); 00381 for (s = 0; s < num_samples; s++) 00382 { 00383 fprintf(fp, "?"); 00384 for (m = 0; m < num_markers; m++) 00385 { 00386 if (markers->elts[ s ][ m ] > 0) 00387 { 00388 fprintf(fp, ",%u", markers->elts[ s ][ m ]); 00389 } 00390 else 00391 { 00392 fprintf(fp, ",?"); 00393 } 00394 } 00395 fprintf(fp, "\n"); 00396 } 00397 00398 if (fclose(fp) != 0) 00399 { 00400 snprintf(str, slen, "%s: %s", tmp_data_fname, strerror(errno)); 00401 return JWSC_EIO(str); 00402 } 00403 00404 snprintf(script, 4096, 00405 "WEKA_JAR=\"%s\"; \ 00406 JAVA=\"%s\"; \ 00407 AWK=\"%s\"; \ 00408 SED=\"%s\"; \ 00409 TMP_DATA=\"%s\"; \ 00410 TMP_RESULT=\"%s\"; \ 00411 JVM_OPTS=\"-Xmx1024m -cp $WEKA_JAR\"; \ 00412 J48=\"weka.classifiers.trees.J48\"; \ 00413 J48_OPTS=\"-c 1 -T $TMP_DATA -l %s -p 0\"; \ 00414 $JAVA $JVM_OPTS $J48 $J48_OPTS \ 00415 | $SED '$d' \ 00416 | $SED 's,^[^ ]\\+ \\(.*\\) \\([^ ]\\+\\) [^ ]\\+[ ]*$,\\1\\n\\2,' \ 00417 > $TMP_RESULT", 00418 weka_jar_fname, HAPLO_JAVA, HAPLO_AWK, HAPLO_SED, tmp_data_fname, 00419 tmp_result_fname, model_fname); 00420 system(script); 00421 00422 if (!(fp = fopen(tmp_result_fname, "r"))) 00423 { 00424 snprintf(str, slen, "%s: %s", tmp_result_fname, strerror(errno)); 00425 return JWSC_EIO(str); 00426 } 00427 00428 for (s = 0; s < num_samples; s++) 00429 { 00430 label_buf_ptr = label_buf; 00431 while ((c = fgetc(fp)) != EOF && c != '\n' && 00432 label_buf_ptr != (label_buf+256)) 00433 { 00434 *label_buf_ptr = c; 00435 label_buf_ptr++; 00436 } 00437 if (label_buf_ptr != (label_buf+256)) 00438 { 00439 *label_buf_ptr = 0; 00440 } 00441 00442 if (fscanf(fp, "%lf\n", &conf) != 1) 00443 { 00444 snprintf(str, slen, "%s: Sample %d: %s", tmp_result_fname, 00445 s+1, "Invalid output from Weka 
J48"); 00446 return JWSC_EIO(str); 00447 } 00448 if ((err = lookup_haplo_group_index_from_label(&label, label_buf))) 00449 { 00450 snprintf(str, slen, "%s: Sample %d: %s", tmp_result_fname, s+1, 00451 err->msg); 00452 return JWSC_EIO(str); 00453 } 00454 00455 (*labels_out)->elts[ s ] = label; 00456 (*confs_out)->elts[ s ] = conf; 00457 } 00458 00459 if (fclose(fp) != 0) 00460 { 00461 snprintf(str, slen, "%s: %s", tmp_result_fname, strerror(errno)); 00462 return JWSC_EIO(str); 00463 } 00464 00465 snprintf(script, 4096, "rm -f %s", tmp_data_fname); 00466 system(script); 00467 snprintf(script, 4096, "rm -f %s", tmp_result_fname); 00468 system(script); 00469 00470 free(labels_str); 00471 00472 return NULL; 00473 } 00474 00475 00486 Error* predict_labels_with_weka_part_model 00487 ( 00488 Vector_u32** labels_out, 00489 Vector_d** confs_out, 00490 const Matrix_i32* markers, 00491 const char* labels_fname, 00492 const char* model_fname, 00493 const char* weka_jar_fname 00494 ) 00495 { 00496 static char script[4096] = {0}; 00497 static char tmp_data_fname[256] = {0}; 00498 static char tmp_result_fname[256] = {0}; 00499 char label_buf[256] = {0}; 00500 char* label_buf_ptr; 00501 char c; 00502 00503 FILE* fp; 00504 uint32_t pid; 00505 uint32_t s, num_samples; 00506 uint32_t m, num_markers; 00507 uint32_t label; 00508 double conf; 00509 char* labels_str = NULL; 00510 Error* err; 00511 int slen = 1024; 00512 char str[slen]; 00513 00514 if ((err = read_weka_model_labels(&labels_str, labels_fname))) 00515 { 00516 return err; 00517 } 00518 00519 pid = getpid(); 00520 num_samples = markers->num_rows; 00521 num_markers = markers->num_cols; 00522 00523 create_vector_u32(labels_out, num_samples); 00524 create_vector_d(confs_out, num_samples); 00525 00526 snprintf(tmp_data_fname, 256, "/tmp/.weka_part_predict_data-%u", pid); 00527 snprintf(tmp_result_fname, 256, "/tmp/.weka_part_predict_result-%u", pid); 00528 00529 if ((fp = fopen(tmp_data_fname, "w")) == NULL) 00530 { 00531 
snprintf(str, slen, "%s: %s", tmp_data_fname, strerror(errno)); 00532 return JWSC_EIO(str); 00533 } 00534 00535 fprintf(fp, "@relation %s\n\n", tmp_data_fname); 00536 00537 fprintf(fp, "@attribute label %s\n", labels_str); 00538 00539 for (m = 0; m < num_markers; m++) 00540 { 00541 fprintf(fp, "@attribute %u numeric\n", m); 00542 } 00543 00544 fprintf(fp, "\n@data\n"); 00545 for (s = 0; s < num_samples; s++) 00546 { 00547 fprintf(fp, "?"); 00548 for (m = 0; m < num_markers; m++) 00549 { 00550 if (markers->elts[ s ][ m ] > 0) 00551 { 00552 fprintf(fp, ",%u", markers->elts[ s ][ m ]); 00553 } 00554 else 00555 { 00556 fprintf(fp, ",?"); 00557 } 00558 } 00559 fprintf(fp, "\n"); 00560 } 00561 00562 if (fclose(fp) != 0) 00563 { 00564 snprintf(str, slen, "%s: %s", tmp_data_fname, strerror(errno)); 00565 return JWSC_EIO(str); 00566 } 00567 00568 snprintf(script, 4096, 00569 "WEKA_JAR=\"%s\"; \ 00570 JAVA=\"%s\"; \ 00571 AWK=\"%s\"; \ 00572 SED=\"%s\"; \ 00573 TMP_DATA=\"%s\"; \ 00574 TMP_RESULT=\"%s\"; \ 00575 JVM_OPTS=\"-Xmx1024m -cp $WEKA_JAR\"; \ 00576 PART=\"weka.classifiers.rules.PART\"; \ 00577 PART_OPTS=\"-c 1 -T $TMP_DATA -l %s -p 0\"; \ 00578 $JAVA $JVM_OPTS $PART $PART_OPTS \ 00579 | $SED '$d' \ 00580 | $SED 's,^[^ ]\\+ \\(.*\\) \\([^ ]\\+\\) [^ ]\\+[ ]*$,\\1\\n\\2,' \ 00581 > $TMP_RESULT", 00582 weka_jar_fname, HAPLO_JAVA, HAPLO_AWK, HAPLO_SED, tmp_data_fname, 00583 tmp_result_fname, model_fname); 00584 system(script); 00585 00586 if (!(fp = fopen(tmp_result_fname, "r"))) 00587 { 00588 snprintf(str, slen, "%s: %s", tmp_result_fname, strerror(errno)); 00589 return JWSC_EIO(str); 00590 } 00591 00592 for (s = 0; s < num_samples; s++) 00593 { 00594 label_buf_ptr = label_buf; 00595 while ((c = fgetc(fp)) != EOF && c != '\n' && 00596 label_buf_ptr != (label_buf+256)) 00597 { 00598 *label_buf_ptr = c; 00599 label_buf_ptr++; 00600 } 00601 if (label_buf_ptr != (label_buf+256)) 00602 { 00603 *label_buf_ptr = 0; 00604 } 00605 00606 if (fscanf(fp, "%lf\n", &conf) != 1) 
00607 { 00608 snprintf(str, slen, "%s: Sample %d: %s", tmp_result_fname, 00609 s+1, "Invalid output from Weka PART"); 00610 return JWSC_EIO(str); 00611 } 00612 if ((err = lookup_haplo_group_index_from_label(&label, label_buf))) 00613 { 00614 snprintf(str, slen, "%s: Sample %d: %s", tmp_result_fname, 00615 s+1, err->msg); 00616 return JWSC_EIO(str); 00617 } 00618 00619 (*labels_out)->elts[ s ] = label; 00620 (*confs_out)->elts[ s ] = conf; 00621 } 00622 00623 if (fclose(fp) != 0) 00624 { 00625 snprintf(str, slen, "%s: %s", tmp_result_fname, strerror(errno)); 00626 return JWSC_EIO(str); 00627 } 00628 00629 snprintf(script, 4096, "rm -f %s", tmp_data_fname); 00630 system(script); 00631 snprintf(script, 4096, "rm -f %s", tmp_result_fname); 00632 system(script); 00633 00634 free(labels_str); 00635 00636 return NULL; 00637 } 00638 00639 00641 static Error* read_weka_xml_doc 00642 ( 00643 xmlDoc** xml_doc_out, 00644 const char* xml_fname, 00645 const char* dtd_fname 00646 ) 00647 { 00648 xmlParserCtxt* xml_parse_ctxt; 00649 xmlValidCtxt* xml_valid_ctxt; 00650 xmlDtd* xml_dtd; 00651 int slen = 256; 00652 char str[slen]; 00653 00654 assert(xml_parse_ctxt = xmlNewParserCtxt()); 00655 00656 if (!(*xml_doc_out = xmlCtxtReadFile(xml_parse_ctxt, xml_fname, NULL, 0))) 00657 { 00658 snprintf(str, slen, "%s: %s", xml_fname, "Could not parse file"); 00659 return JWSC_EARG(str); 00660 } 00661 00662 xmlFreeParserCtxt(xml_parse_ctxt); 00663 00664 if (dtd_fname) 00665 { 00666 assert(xml_valid_ctxt = xmlNewValidCtxt()); 00667 00668 if (!(xml_dtd = xmlParseDTD(NULL, (xmlChar*)dtd_fname))) 00669 { 00670 snprintf(str, slen, "%s: %s", dtd_fname, "Could not parse DTD"); 00671 return JWSC_EARG(str); 00672 } 00673 00674 if (!xmlValidateDtd(xml_valid_ctxt, *xml_doc_out, xml_dtd)) 00675 { 00676 snprintf(str, slen, "%s: %s", xml_fname, "XML file not valid"); 00677 return JWSC_EARG(str); 00678 } 00679 00680 xmlFreeValidCtxt(xml_valid_ctxt); 00681 xmlFreeDtd(xml_dtd); 00682 } 00683 00684 return 
NULL; 00685 } 00686 00687 00689 static Error* create_weka_model_tree_from_xml_node 00690 ( 00691 Weka_model_tree** tree_out, 00692 const Weka_model_tree* parent, 00693 uint32_t parent_label, 00694 xmlNode* xml_node, 00695 const char* model_dirname 00696 ) 00697 { 00698 Weka_model_tree* tree; 00699 uint32_t i, j; 00700 uint32_t len; 00701 uint32_t label; 00702 uint32_t altlabel; 00703 uint32_t num_altlabels; 00704 xmlNode* it; 00705 xmlNode* itt; 00706 Error* err; 00707 const char* fname; 00708 char* fname_buf; 00709 00710 if (*tree_out) 00711 { 00712 free_weka_model_tree(*tree_out); 00713 } 00714 00715 assert(*tree_out = malloc(sizeof(Weka_model_tree))); 00716 tree = *tree_out; 00717 00718 tree->parent = parent; 00719 tree->parent_label = parent_label; 00720 tree->num_groups = 0; 00721 tree->subtrees = NULL; 00722 tree->labels = NULL; 00723 tree->altlabels = NULL; 00724 tree->labels_fname = NULL; 00725 tree->model_fname = NULL; 00726 00727 assert(XMLStrEqual(xml_node->name, "model")); 00728 00729 for (it = xml_node->children; it; it = it->next) 00730 { 00731 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "group")) 00732 { 00733 tree->num_groups++; 00734 } 00735 } 00736 00737 assert(tree->subtrees = calloc(tree->num_groups, sizeof(void*))); 00738 assert(tree->altlabels = calloc(tree->num_groups, sizeof(void*))); 00739 create_vector_u32(&(tree->labels), tree->num_groups); 00740 00741 i = 0; 00742 for (it = xml_node->children; it; it = it->next) 00743 { 00744 j = 0; 00745 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model-file")) 00746 { 00747 fname = (const char*) it->children->content; 00748 len = strlen(model_dirname) + strlen(fname) + 2; 00749 fname_buf = malloc(len * sizeof(char)); 00750 snprintf(fname_buf, len, "%s/%s", model_dirname, fname); 00751 tree->model_fname = fname_buf; 00752 } 00753 else if (it->type == XML_ELEMENT_NODE && 00754 XMLStrEqual(it->name, "labels-file")) 00755 { 00756 fname = (const char*) it->children->content; 
00757 len = strlen(model_dirname) + strlen(fname) + 2; 00758 fname_buf = malloc(len * sizeof(char)); 00759 snprintf(fname_buf, len, "%s/%s", model_dirname, fname); 00760 tree->labels_fname = fname_buf; 00761 } 00762 else if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "group")) 00763 { 00764 num_altlabels = 0; 00765 for (itt = it->children; itt; itt = itt->next) 00766 { 00767 if (itt->type == XML_ELEMENT_NODE && 00768 XMLStrEqual(itt->name, "altlabel")) 00769 { 00770 num_altlabels++; 00771 } 00772 } 00773 if (num_altlabels) 00774 { 00775 create_vector_u32(&(tree->altlabels[ i ]), num_altlabels); 00776 } 00777 00778 for (itt = it->children; itt; itt = itt->next) 00779 { 00780 if (itt->type == XML_ELEMENT_NODE && 00781 XMLStrEqual(itt->name, "label")) 00782 { 00783 if ((err = lookup_haplo_group_index_from_label(&label, 00784 (const char*) itt->children->content))) 00785 { 00786 return err; 00787 } 00788 tree->labels->elts[ i ] = label; 00789 } 00790 else if (itt->type == XML_ELEMENT_NODE && 00791 XMLStrEqual(itt->name, "altlabel")) 00792 { 00793 if ((err = lookup_haplo_group_index_from_label(&altlabel, 00794 (const char*) itt->children->content))) 00795 { 00796 return err; 00797 } 00798 tree->altlabels[ i ]->elts[ j++ ] = altlabel; 00799 } 00800 else if (itt->type == XML_ELEMENT_NODE && 00801 XMLStrEqual(itt->name, "model")) 00802 { 00803 if ((err = create_weka_model_tree_from_xml_node( 00804 &(tree->subtrees[ i ]), tree, label, itt, 00805 model_dirname))) 00806 { 00807 return err; 00808 } 00809 } 00810 } 00811 i++; 00812 } 00813 } 00814 assert(i == tree->num_groups); 00815 00816 return NULL; 00817 } 00818 00819 00824 static Error* create_model_training_data 00825 ( 00826 Vector_u32** train_labels_out, 00827 Matrix_i32** train_markers_out, 00828 const Vector_u32* data_labels, 00829 const Matrix_i32* data_markers, 00830 const Vector_u32* model_labels, 00831 Vector_u32*const* model_altlabels 00832 ) 00833 { 00834 uint8_t b; 00835 uint32_t i, j, k; 00836 
uint32_t n; 00837 Vector_u32* train_labels = NULL; 00838 Matrix_i32* train_markers = NULL; 00839 00840 copy_vector_u32(&train_labels, data_labels); 00841 copy_matrix_i32(&train_markers, data_markers); 00842 00843 n = 0; 00844 for (i = 0; i < data_labels->num_elts; i++) 00845 { 00846 for (j = 0; j < model_labels->num_elts; j++) 00847 { 00848 if (is_ancestor(data_labels->elts[ i ], model_labels->elts[ j ])) 00849 { 00850 train_labels->elts[ n ] = model_labels->elts[ j ]; 00851 copy_matrix_block_into_matrix_i32(train_markers, n, 0, 00852 data_markers, i, 0, 1, data_markers->num_cols); 00853 n++; 00854 break; 00855 } 00856 else if (model_altlabels[ j ]) 00857 { 00858 for (k = 0; k < model_altlabels[ j ]->num_elts; k++) 00859 { 00860 if ((b = is_ancestor(data_labels->elts[ i ], 00861 model_altlabels[ j ]->elts[ k ]))) 00862 { 00863 train_labels->elts[ n ] = model_labels->elts[ j ]; 00864 copy_matrix_block_into_matrix_i32(train_markers, n, 0, 00865 data_markers, i, 0, 1, data_markers->num_cols); 00866 n++; 00867 break; 00868 } 00869 } 00870 if (b) 00871 { 00872 break; 00873 } 00874 } 00875 } 00876 } 00877 00878 if (!n) 00879 { 00880 return JWSC_EARG("No data for model"); 00881 } 00882 00883 copy_vector_section_u32(train_labels_out, train_labels, 0, n); 00884 copy_matrix_block_i32(train_markers_out, train_markers, 0, 0, n, 00885 train_markers->num_cols); 00886 00887 free_vector_u32(train_labels); 00888 free_matrix_i32(train_markers); 00889 00890 return NULL; 00891 } 00892 00893 00895 static Error* train_weka_j48_model_node 00896 ( 00897 Weka_model_node* node, 00898 const Vector_u32* labels, 00899 const Matrix_i32* markers, 00900 const char* weka_jar_fname 00901 ) 00902 { 00903 uint32_t i; 00904 Vector_u32* node_labels = NULL; 00905 Matrix_i32* node_markers = NULL; 00906 Error* err; 00907 int slen = 256; 00908 char str[slen]; 00909 00910 if ((err = create_model_training_data(&node_labels, &node_markers, labels, 00911 markers, node->labels, node->altlabels))) 00912 { 00913 
snprintf(str, slen, "%s: %s", node->model_fname, err->msg); 00914 return JWSC_EARG(str); 00915 } 00916 00917 if ((err = train_weka_j48_model(node_labels, node_markers, 00918 node->labels_fname, node->model_fname, weka_jar_fname))) 00919 { 00920 return err; 00921 } 00922 00923 for (i = 0; i < node->num_groups; i++) 00924 { 00925 if (node->subtrees[ i ]) 00926 { 00927 if ((err = train_weka_j48_model_node(node->subtrees[ i ], labels, 00928 markers, weka_jar_fname))) 00929 { 00930 return err; 00931 } 00932 } 00933 } 00934 00935 free_vector_u32(node_labels); 00936 free_matrix_i32(node_markers); 00937 00938 return NULL; 00939 } 00940 00941 00943 static Error* train_weka_part_model_node 00944 ( 00945 Weka_model_node* node, 00946 const Vector_u32* labels, 00947 const Matrix_i32* markers, 00948 const char* weka_jar_fname 00949 ) 00950 { 00951 uint32_t i; 00952 Vector_u32* node_labels = NULL; 00953 Matrix_i32* node_markers = NULL; 00954 Error* err; 00955 int slen = 256; 00956 char str[slen]; 00957 00958 if ((err = create_model_training_data(&node_labels, &node_markers, labels, 00959 markers, node->labels, node->altlabels))) 00960 { 00961 snprintf(str, slen, "%s: %s", node->model_fname, err->msg); 00962 return JWSC_EARG(str); 00963 } 00964 00965 if ((err = train_weka_part_model(node_labels, node_markers, 00966 node->labels_fname, node->model_fname, weka_jar_fname))) 00967 { 00968 return err; 00969 } 00970 00971 for (i = 0; i < node->num_groups; i++) 00972 { 00973 if (node->subtrees[ i ]) 00974 { 00975 if ((err = train_weka_part_model_node(node->subtrees[ i ], labels, 00976 markers, weka_jar_fname))) 00977 { 00978 return err; 00979 } 00980 } 00981 } 00982 00983 free_vector_u32(node_labels); 00984 free_matrix_i32(node_markers); 00985 00986 return NULL; 00987 } 00988 00989 00999 Error* train_weka_j48_model_tree 01000 ( 01001 Weka_model_tree** tree_out, 01002 const Vector_u32* labels, 01003 const Matrix_i32* markers, 01004 const char* tree_xml_fname, 01005 const char* 
tree_dtd_fname, 01006 const char* model_dirname, 01007 const char* weka_jar_fname 01008 ) 01009 { 01010 xmlDoc* xml_doc; 01011 xmlNode* xml_node; 01012 xmlNode* it; 01013 Error* err; 01014 int slen = 256; 01015 char str[slen]; 01016 01017 assert(tree_xml_fname); 01018 01019 if ((err = read_weka_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname))) 01020 { 01021 return err; 01022 } 01023 01024 for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next) 01025 { 01026 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model")) 01027 { 01028 xml_node = it; 01029 break; 01030 } 01031 } 01032 01033 if ((err = create_weka_model_tree_from_xml_node(tree_out, NULL, 0, 01034 xml_node, model_dirname))) 01035 { 01036 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 01037 return JWSC_EARG(str); 01038 } 01039 01040 if ((err = train_weka_j48_model_node(*tree_out, labels, markers, 01041 weka_jar_fname))) 01042 { 01043 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 01044 return JWSC_EARG(str); 01045 } 01046 01047 return NULL; 01048 } 01049 01050 01060 Error* train_weka_part_model_tree 01061 ( 01062 Weka_model_tree** tree_out, 01063 const Vector_u32* labels, 01064 const Matrix_i32* markers, 01065 const char* tree_xml_fname, 01066 const char* tree_dtd_fname, 01067 const char* model_dirname, 01068 const char* weka_jar_fname 01069 ) 01070 { 01071 xmlDoc* xml_doc; 01072 xmlNode* xml_node; 01073 xmlNode* it; 01074 Error* err; 01075 int slen = 256; 01076 char str[slen]; 01077 01078 assert(tree_xml_fname); 01079 01080 if ((err = read_weka_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname))) 01081 { 01082 return err; 01083 } 01084 01085 for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next) 01086 { 01087 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model")) 01088 { 01089 xml_node = it; 01090 break; 01091 } 01092 } 01093 01094 if ((err = create_weka_model_tree_from_xml_node(tree_out, NULL, 0, 01095 xml_node, model_dirname))) 01096 
{ 01097 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 01098 return JWSC_EARG(str); 01099 } 01100 01101 if ((err = train_weka_part_model_node(*tree_out, labels, markers, 01102 weka_jar_fname))) 01103 { 01104 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 01105 return JWSC_EARG(str); 01106 } 01107 01108 return NULL; 01109 } 01110 01111 01113 static Error* recursively_predict_j48_labels_in_model_tree 01114 ( 01115 Vector_u32** labels_out, 01116 Vector_d** confs_in_out, 01117 const Weka_model_tree* tree, 01118 const Matrix_i32* markers, 01119 const char* weka_jar_fname 01120 ) 01121 { 01122 uint32_t s, num_samples; 01123 uint32_t num_markers; 01124 uint32_t n, N; 01125 uint32_t i; 01126 uint32_t subtree_label; 01127 Vector_d* confs = NULL; 01128 Vector_u32* sample_indices = NULL; 01129 Matrix_i32* markers_subset = NULL; 01130 Vector_u32* labels_subset = NULL; 01131 Vector_d* confs_subset = NULL; 01132 Error* err; 01133 01134 num_samples = markers->num_rows; 01135 num_markers = markers->num_cols; 01136 01137 assert((*confs_in_out)->num_elts == num_samples); 01138 01139 if ((err = predict_labels_with_weka_j48_model(labels_out, &confs, 01140 markers, tree->labels_fname, tree->model_fname, 01141 weka_jar_fname)) != NULL) 01142 { 01143 return err; 01144 } 01145 for (s = 0; s < num_samples; s++) 01146 { 01147 (*confs_in_out)->elts[ s ] *= confs->elts[ s ]; 01148 } 01149 01150 create_vector_u32(&sample_indices, num_samples); 01151 01152 for (i = 0; i < tree->num_groups; i++) 01153 { 01154 if (!tree->subtrees[ i ]) 01155 continue; 01156 01157 subtree_label = tree->subtrees[ i ]->parent_label; 01158 01159 N = 0; 01160 for (s = 0; s < num_samples; s++) 01161 { 01162 if ((*labels_out)->elts[ s ] == subtree_label) 01163 { 01164 sample_indices->elts[ N++ ] = s; 01165 } 01166 } 01167 01168 if (!N) 01169 continue; 01170 01171 create_matrix_i32(&markers_subset, N, num_markers); 01172 create_vector_d(&confs_subset, N); 01173 01174 for (n = 0; n < N; n++) 01175 { 
01176 s = sample_indices->elts[ n ]; 01177 copy_matrix_block_into_matrix_i32(markers_subset, n, 0, 01178 markers, s, 0, 1, num_markers); 01179 confs_subset->elts[ n ] = (*confs_in_out)->elts[ s ]; 01180 } 01181 01182 if ((err = recursively_predict_j48_labels_in_model_tree( 01183 &labels_subset, &confs_subset, tree->subtrees[ i ], 01184 markers_subset, weka_jar_fname)) != NULL) 01185 { 01186 return err; 01187 } 01188 01189 for (n = 0; n < N; n++) 01190 { 01191 s = sample_indices->elts[ n ]; 01192 (*labels_out)->elts[ s ] = labels_subset->elts[ n ]; 01193 (*confs_in_out)->elts[ s ] = confs_subset->elts[ n ]; 01194 } 01195 } 01196 01197 free_vector_u32(sample_indices); 01198 free_matrix_i32(markers_subset); 01199 free_vector_u32(labels_subset); 01200 free_vector_d(confs_subset); 01201 01202 return NULL; 01203 } 01204 01205 01207 static Error* recursively_predict_part_labels_in_model_tree 01208 ( 01209 Vector_u32** labels_out, 01210 Vector_d** confs_in_out, 01211 const Weka_model_tree* tree, 01212 const Matrix_i32* markers, 01213 const char* weka_jar_fname 01214 ) 01215 { 01216 uint32_t s, num_samples; 01217 uint32_t num_markers; 01218 uint32_t n, N; 01219 uint32_t i; 01220 uint32_t subtree_label; 01221 Vector_d* confs = NULL; 01222 Vector_u32* sample_indices = NULL; 01223 Matrix_i32* markers_subset = NULL; 01224 Vector_u32* labels_subset = NULL; 01225 Vector_d* confs_subset = NULL; 01226 Error* err; 01227 01228 num_samples = markers->num_rows; 01229 num_markers = markers->num_cols; 01230 01231 assert((*confs_in_out)->num_elts == num_samples); 01232 01233 if ((err = predict_labels_with_weka_part_model(labels_out, &confs, 01234 markers, tree->labels_fname, tree->model_fname, 01235 weka_jar_fname)) != NULL) 01236 { 01237 return err; 01238 } 01239 for (s = 0; s < num_samples; s++) 01240 { 01241 (*confs_in_out)->elts[ s ] *= confs->elts[ s ]; 01242 } 01243 01244 create_vector_u32(&sample_indices, num_samples); 01245 01246 for (i = 0; i < tree->num_groups; i++) 01247 { 
01248 if (!tree->subtrees[ i ]) 01249 continue; 01250 01251 subtree_label = tree->subtrees[ i ]->parent_label; 01252 01253 N = 0; 01254 for (s = 0; s < num_samples; s++) 01255 { 01256 if ((*labels_out)->elts[ s ] == subtree_label) 01257 { 01258 sample_indices->elts[ N++ ] = s; 01259 } 01260 } 01261 01262 if (!N) 01263 continue; 01264 01265 create_matrix_i32(&markers_subset, N, num_markers); 01266 create_vector_d(&confs_subset, N); 01267 01268 for (n = 0; n < N; n++) 01269 { 01270 s = sample_indices->elts[ n ]; 01271 copy_matrix_block_into_matrix_i32(markers_subset, n, 0, 01272 markers, s, 0, 1, num_markers); 01273 confs_subset->elts[ n ] = (*confs_in_out)->elts[ s ]; 01274 } 01275 01276 if ((err = recursively_predict_part_labels_in_model_tree( 01277 &labels_subset, &confs_subset, tree->subtrees[ i ], 01278 markers_subset, weka_jar_fname)) != NULL) 01279 { 01280 return err; 01281 } 01282 01283 for (n = 0; n < N; n++) 01284 { 01285 s = sample_indices->elts[ n ]; 01286 (*labels_out)->elts[ s ] = labels_subset->elts[ n ]; 01287 (*confs_in_out)->elts[ s ] = confs_subset->elts[ n ]; 01288 } 01289 } 01290 01291 free_vector_u32(sample_indices); 01292 free_matrix_i32(markers_subset); 01293 free_vector_u32(labels_subset); 01294 free_vector_d(confs_subset); 01295 01296 return NULL; 01297 } 01298 01299 01314 Error* predict_labels_with_weka_j48_model_tree 01315 ( 01316 Vector_u32** labels_out, 01317 Vector_d** confs_out, 01318 const Matrix_i32* markers, 01319 const Weka_model_tree* tree, 01320 const char* weka_jar_fname 01321 ) 01322 { 01323 Error* err; 01324 01325 create_init_vector_d(confs_out, markers->num_rows, 1.0); 01326 01327 if ((err = recursively_predict_j48_labels_in_model_tree(labels_out, 01328 confs_out, tree, markers, weka_jar_fname)) != NULL) 01329 { 01330 return err; 01331 } 01332 01333 return NULL; 01334 } 01335 01336 01351 Error* predict_labels_with_weka_part_model_tree 01352 ( 01353 Vector_u32** labels_out, 01354 Vector_d** confs_out, 01355 const Matrix_i32* 
markers, 01356 const Weka_model_tree* tree, 01357 const char* weka_jar_fname 01358 ) 01359 { 01360 Error* err; 01361 01362 create_init_vector_d(confs_out, markers->num_rows, 1.0); 01363 01364 if ((err = recursively_predict_part_labels_in_model_tree(labels_out, 01365 confs_out, tree, markers, weka_jar_fname)) != NULL) 01366 { 01367 return err; 01368 } 01369 01370 return NULL; 01371 } 01372 01373 01380 Error* read_weka_model_tree 01381 ( 01382 Weka_model_tree** tree_out, 01383 const char* tree_xml_fname, 01384 const char* tree_dtd_fname, 01385 const char* model_dirname 01386 ) 01387 { 01388 xmlDoc* xml_doc; 01389 xmlNode* xml_node; 01390 xmlNode* it; 01391 Error* err; 01392 int slen = 256; 01393 char str[slen]; 01394 01395 assert(tree_xml_fname && model_dirname); 01396 01397 if ((err = read_weka_xml_doc(&xml_doc, tree_xml_fname, tree_dtd_fname))) 01398 { 01399 return err; 01400 } 01401 01402 for (it = xmlDocGetRootElement(xml_doc)->children; it; it = it->next) 01403 { 01404 if (it->type == XML_ELEMENT_NODE && XMLStrEqual(it->name, "model")) 01405 { 01406 xml_node = it; 01407 break; 01408 } 01409 } 01410 01411 if ((err = create_weka_model_tree_from_xml_node(tree_out, NULL, 0, 01412 xml_node, model_dirname))) 01413 { 01414 snprintf(str, slen, "%s: %s", tree_xml_fname, err->msg); 01415 return JWSC_EARG(str); 01416 } 01417 01418 return NULL; 01419 } 01420 01421 01423 void free_weka_model_tree(Weka_model_tree* tree) 01424 { 01425 uint32_t i; 01426 01427 if (!tree) 01428 return; 01429 01430 for (i = 0; i < tree->num_groups; i++) 01431 { 01432 free_weka_model_tree(tree->subtrees[ i ]); 01433 free_vector_u32(tree->altlabels[ i ]); 01434 } 01435 01436 free(tree->subtrees); 01437 free(tree->altlabels); 01438 free_vector_u32(tree->labels); 01439 free(tree->labels_fname); 01440 free(tree->model_fname); 01441 free(tree); 01442 }