Alternaria
fit cylinders and ellipsoids to fungus
sampler.cpp
Go to the documentation of this file.
00001 /*
00002  * This work is licensed under a Creative Commons 
00003  * Attribution-Noncommercial-Share Alike 3.0 United States License.
00004  * 
00005  *    http://creativecommons.org/licenses/by-nc-sa/3.0/us/
00006  * 
00007  * You are free:
00008  * 
00009  *    to Share - to copy, distribute, display, and perform the work
00010  *    to Remix - to make derivative works
00011  * 
00012  * Under the following conditions:
00013  * 
00014  *    Attribution. You must attribute the work in the manner specified by the
00015  *    author or licensor (but not in any way that suggests that they endorse you
00016  *    or your use of the work).
00017  * 
00018  *    Noncommercial. You may not use this work for commercial purposes.
00019  * 
00020  *    Share Alike. If you alter, transform, or build upon this work, you may
00021  *    distribute the resulting work only under the same or similar license to
00022  *    this one.
00023  * 
00024  * For any reuse or distribution, you must make clear to others the license
00025  * terms of this work. The best way to do this is by including this header.
00026  * 
00027  * Any of the above conditions can be waived if you get permission from the
00028  * copyright holder.
00029  * 
00030  * Apart from the remix rights granted under this license, nothing in this
00031  * license impairs or restricts the author's moral rights.
00032  */
00033 
00034 
#include <config.h>

#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <iomanip>
#include <fstream>
#include <sstream>
#include <list>

#include <inttypes.h>
#include <sys/time.h>

#include <jwsc/prob/pdf.h>
#include <jwsc/matblock/matblock.h>

#include <jwsc++/base/exception.h>

#include "alternaria_model.h"
#include "psf_model.h"
#include "imaging_model.h"
#include "sampler.h"
00067 
00068 
00069 using std::list;
00070 using std::ostream;
00071 using std::ofstream;
00072 using std::ostringstream;
00073 using namespace jwsc;
00074 using jwscxx::base::Exception;
00075 using jwscxx::base::Arg_error;
00076 using jwscxx::base::IO_error;
00077 
00078 
00087 Sampler_move_parameters::Sampler_move_parameters
00088 (
00089     Alternaria_model*       alternaria,
00090     PSF_model*              psf,
00091     Imaging_model*          imaging,
00092     const jwsc::Matblock_f* data
00093 )
00094 throw (Arg_error)
00095 {
00096     this->alternaria = alternaria;
00097     this->psf        = psf;
00098     this->imaging    = imaging;
00099     this->data       = data;
00100 
00101     data_avg = 0;
00102 
00103     uint32_t num_elts = data->num_mats*data->num_rows*data->num_cols;
00104     float*   elts     = **(data->elts);
00105 
00106     for (uint32_t i = 0; i < num_elts; i++)
00107     {
00108         data_avg += elts[i];
00109     }
00110     data_avg /= (float)num_elts;
00111 
00112     ll = alternaria->calc_log_likelihood(psf, imaging, data, data_avg);
00113 
00114     alternaria_proposal = alternaria->clone();
00115     psf_proposal        = psf->clone();
00116     imaging_proposal    = imaging->clone();
00117     ll_proposal         = ll;
00118 }
00119 
00120 
00122 Sampler_move_parameters::~Sampler_move_parameters()
00123 {
00124     delete alternaria;
00125     delete psf;
00126     delete imaging;
00127     delete alternaria_proposal;
00128     delete psf_proposal;
00129     delete imaging_proposal;
00130 }
00131 
00132 
00134 void Sampler_move_parameters::init_proposals()
00135 {
00136     *alternaria_proposal = *alternaria;
00137     *psf_proposal        = *psf;
00138     *imaging_proposal    = *imaging;
00139     ll_proposal          = ll;
00140 }
00141 
00142 
00152 Sampler_diffusion_move::Sampler_diffusion_move(const char* name, double prob) 
00153     throw (Arg_error)
00154 {
00155     num_accepted = 0;
00156     num_attempts = 0;
00157     this->name   = name;
00158     set_prob(prob);
00159 }
00160 
00161 
00167 void Sampler_diffusion_move::set_prob(double prob) throw (Arg_error)
00168 {
00169     if (0 > prob || prob > 1.0)
00170     {
00171         throw Arg_error("Sampler move probability not in [0,1]");
00172     }
00173 
00174     this->prob = prob;
00175 }
00176 
00177 
00189 Sampler_jump_move::Sampler_jump_move
00190 (
00191     const char* name_1, 
00192     const char* name_2,
00193     double      prob_1,
00194     double      prob_2
00195 ) 
00196 throw (Arg_error)
00197 {
00198     num_accepted_1 = 0;
00199     num_attempts_1 = 0;
00200     this->name_1   = name_1;
00201     set_prob_1(prob_1);
00202 
00203     num_accepted_2 = 0;
00204     num_attempts_2 = 0;
00205     this->name_2   = name_2;
00206     set_prob_2(prob_2);
00207 
00208     name = name_1;
00209 }
00210 
00211 
00217 bool Sampler_jump_move::run(Sampler_move_parameters* params) throw (Exception)
00218 {
00219     if (sample_uniform_pdf_d(0, prob_1+prob_2) <= prob_1)
00220     {
00221         name = name_1;
00222         return run_1(params);
00223     }
00224     name = name_2;
00225     return run_2(params);
00226 }
00227 
00228 
00234 const std::ostringstream& Sampler_jump_move::get_info
00235 (
00236     Sampler_move_parameters* params
00237 ) 
00238 {
00239     assert(params->alternaria);
00240     info.str("");
00241     info << std::left << std::setw(4) 
00242          << params->alternaria->get_num_levels();
00243     info << std::left << std::setw(4) 
00244          << params->alternaria->get_num_apical_hypha();
00245     info << std::left << std::setw(4) 
00246          << params->alternaria->get_num_lateral_hypha();
00247     info << std::left << std::setw(4) 
00248          << params->alternaria->get_num_spores();
00249     return info;
00250 }
00251 
00252 
00258 const std::ostringstream& Sampler_jump_move::get_proposal_info
00259 (
00260     Sampler_move_parameters* params
00261 )
00262 { 
00263     assert(params->alternaria_proposal);
00264     proposal_info.str("");
00265     proposal_info << std::left << std::setw(4)
00266                   << params->alternaria_proposal->get_num_levels();
00267     proposal_info << std::left << std::setw(4) 
00268                   << params->alternaria_proposal->get_num_apical_hypha();
00269     proposal_info << std::left << std::setw(4)
00270                   << params->alternaria_proposal->get_num_lateral_hypha();
00271     proposal_info << std::left << std::setw(4)
00272                   << params->alternaria_proposal->get_num_spores();
00273     return proposal_info; 
00274 }
00275 
00276 
00282 void Sampler_jump_move::set_prob(double prob) throw (Arg_error)
00283 {
00284     if (0 > prob || prob > 1.0)
00285     {
00286         throw Arg_error("Sampler move probability not in [0,1]");
00287     }
00288 
00289     if ((prob_1 + prob_2) == 0 || prob == 0)
00290     {
00291         set_prob_1(0);
00292         set_prob_2(0);
00293     }
00294 
00295     double d = (prob_1 + prob_2) / prob;
00296 
00297     set_prob_1(prob_1/d);
00298     set_prob_2(prob_2/d);
00299 }
00300 
00301 
00307 void Sampler_jump_move::set_prob_1(double prob_1) throw (Arg_error)
00308 {
00309     if (0 > prob_1 || prob_1 > 1.0)
00310     {
00311         throw Arg_error("Sampler move probability not in [0,1]");
00312     }
00313 
00314     this->prob_1 = prob_1;
00315 }
00316 
00317 
00323 void Sampler_jump_move::set_prob_2(double prob_2) throw (Arg_error)
00324 {
00325     if (0 > prob_2 || prob_2 > 1.0)
00326     {
00327         throw Arg_error("Sampler move probability not in [0,1]");
00328     }
00329 
00330     this->prob_2 = prob_2;
00331 }
00332 
00333 
00335 Sampler::~Sampler()
00336 {
00337     for (list<Sampler_move*>::iterator it = moves.begin(); 
00338             it != moves.end(); it++)
00339     {
00340         delete *it;
00341     }
00342 }
00343 
00344 
00346 void Sampler::add_move(Sampler_move* move)
00347 {
00348     moves.push_back(move);
00349 }
00350 
00351 
00353 void Sampler::normalize_move_probs() throw (Arg_error)
00354 {
00355     double sum = 0;
00356 
00357     for (list<Sampler_move*>::iterator it = moves.begin(); 
00358             it != moves.end(); it++)
00359     {
00360         sum += (*it)->get_prob();
00361     }
00362 
00363     if (sum <= 1.0e-10)
00364     {
00365         throw Arg_error("Sum of move probabilities is zero");
00366     }
00367 
00368     for (list<Sampler_move*>::iterator it = moves.begin(); 
00369             it != moves.end(); it++)
00370     {
00371         (*it)->set_prob((*it)->get_prob() / sum);
00372     }
00373 }
00374 
00375 
00409 void Sampler::run
00410 (
00411     Alternaria_model**       best_alternaria,
00412     PSF_model**              best_psf,
00413     Imaging_model**          best_imaging,
00414     Sampler_move_parameters* params,
00415     uint32_t                 iterations,
00416     const char*              move_fname,
00417     const char*              best_fname,
00418     const char*              alt_pro_fmt
00419 ) 
00420 throw (Exception)
00421 {
00422     ostringstream ost;
00423     ofstream move_out;
00424     ofstream best_out;
00425 
00426     move_out.open(move_fname);
00427     if (move_out.fail())
00428     {
00429         ost << move_fname << ": Could not open file";
00430         throw IO_error(ost.str());
00431     }
00432 
00433     best_out.open(best_fname);
00434     if (best_out.fail())
00435     {
00436         ost << best_fname << ": Could not open file";
00437         throw IO_error(ost.str());
00438     }
00439 
00440     run(best_alternaria, best_psf, best_imaging, params, iterations, 
00441             move_out, best_out, alt_pro_fmt);
00442 
00443     move_out.close();
00444     if (move_out.fail())
00445     {
00446         ost << move_fname << ": Could not close file";
00447         throw IO_error(ost.str());
00448     }
00449 
00450     best_out.close();
00451     if (best_out.fail())
00452     {
00453         ost << best_fname << ": Could not close file";
00454         throw IO_error(ost.str());
00455     }
00456 }
00457 
00458 
00492 void Sampler::run
00493 (
00494     Alternaria_model**       best_alternaria,
00495     PSF_model**              best_psf,
00496     Imaging_model**          best_imaging,
00497     Sampler_move_parameters* params,
00498     uint32_t                 iterations,
00499     std::ostream&            move_out,
00500     std::ostream&            best_out,
00501     const char*              alt_pro_fmt
00502 ) 
00503 throw (Exception)
00504 {
00505     assert(*best_alternaria == 0);
00506     assert(*best_psf == 0);
00507     assert(*best_imaging == 0);
00508 
00509     *best_alternaria = params->alternaria->clone();
00510     *best_psf        = params->psf->clone();
00511     *best_imaging    = params->imaging->clone();
00512 
00513     double ll_best = params->ll;
00514     double lp_best = ll_best;
00515     lp_best += params->alternaria->get_log_prob();
00516     lp_best += params->psf->get_log_prob();
00517     lp_best += params->imaging->get_log_prob();
00518 
00519     double lp = params->ll;
00520     lp += params->alternaria->get_log_prob();
00521     lp += params->psf->get_log_prob();
00522     lp += params->imaging->get_log_prob();
00523 
00524     list<Sampler_move*>::iterator it;
00525     list<Sampler_move*>::iterator itt;
00526 
00527     // Print some inital state information for the moves.
00528     best_out << std::left << std::setw(7)  << 0
00529              << std::left << std::setw(14) << ll_best
00530              << std::left << std::setw(14) << lp_best;
00531     for (it = moves.begin(); it != moves.end(); it++)
00532     {
00533         Sampler_jump_move* jump_move;
00534 
00535         if ((jump_move = dynamic_cast<Sampler_jump_move*>(*it)))
00536         {
00537             move_out << std::left << std::setw(7)  << 0
00538                      << std::left << std::setw(12)  << jump_move->get_name_1()
00539                      << std::left << std::setw(8)  << "Init"
00540                      << std::left << std::setw(14) << params->ll
00541                      << std::left << std::setw(14) << lp
00542                      << jump_move->get_info(params).str() << '\n';
00543             move_out << std::left << std::setw(7)  << 0
00544                      << std::left << std::setw(12)  << jump_move->get_name_2()
00545                      << std::left << std::setw(8)  << "Init"
00546                      << std::left << std::setw(14) << params->ll
00547                      << std::left << std::setw(14) << lp
00548                      << jump_move->get_info(params).str() << '\n';
00549         }
00550         else
00551         {
00552             move_out << std::left << std::setw(7)  << 0
00553                      << std::left << std::setw(12)  << (*it)->get_name()
00554                      << std::left << std::setw(8)  << "Init"
00555                      << std::left << std::setw(14) << params->ll
00556                      << std::left << std::setw(14) << lp
00557                      << (*it)->get_info(params).str() << '\n';
00558         }
00559         move_out.flush();
00560 
00561         best_out << (*it)->get_info(params).str();
00562     }
00563     best_out << '\n';
00564     best_out.flush();
00565 
00566     struct timeval tv1;
00567     struct timeval tv2;
00568 
00569     // Iteratively generate model samples, keeping track of the best models.
00570     for (uint32_t i = 1; i <= iterations; i++)
00571     {
00572         double u = sample_uniform_pdf_d(0, 1);
00573         double p = 0;
00574 
00575         move_out << std::left << std::setw(7) << i;
00576 
00577         // Execute one of the sampler moves.
00578         for (it = moves.begin(); it != moves.end(); it++)
00579         {
00580             Sampler_move* move = *it;
00581 
00582             p += move->get_prob();
00583             if (u <= p)
00584             {
00585                 bool accept;
00586 
00587                 try
00588                 {
00589                     assert(gettimeofday(&tv1, NULL) == 0);
00590                     accept = move->run(params);
00591                     assert(gettimeofday(&tv2, NULL) == 0);
00592                 }
00593                 catch (Exception e)
00594                 {
00595                     try
00596                     {
00597                         params->alternaria->get_root()->recursively_check_structure();
00598                     }
00599                     catch (Exception ee)
00600                     {
00601                         // This should never happen!
00602                         std::cerr << "alternaria: sampler.cpp: " 
00603                                   << ee.get_msg() << '\n';
00604                         abort();
00605                     }
00606 
00607                     try
00608                     {
00609                         params->alternaria_proposal->get_root()->recursively_check_structure();
00610                     }
00611                     catch (Exception ee)
00612                     {
00613                         accept = false;
00614                     }
00615                 }
00616 
00617                 double lp_proposal = params->ll_proposal;
00618                 lp_proposal += params->alternaria_proposal->get_log_prob();
00619                 lp_proposal += params->psf_proposal->get_log_prob();
00620                 lp_proposal += params->imaging_proposal->get_log_prob();
00621 
00622                 move_out << std::left << std::setw(12) << move->get_name();
00623 
00624                 if (accept)
00625                 {
00626                     move_out << std::left << std::setw(8) << "Accept";
00627 
00628                     *(params->alternaria) = *(params->alternaria_proposal);
00629                     *(params->psf)        = *(params->psf_proposal);
00630                     *(params->imaging)    = *(params->imaging_proposal);
00631 
00632                     params->ll = params->ll_proposal;
00633 
00634                     // Update the best models if necessary.
00635                     if (lp_proposal > lp_best)
00636                     {
00637                         **best_alternaria = *(params->alternaria_proposal);
00638                         **best_psf        = *(params->psf_proposal);
00639                         **best_imaging    = *(params->imaging_proposal);
00640                         ll_best           = params->ll_proposal;
00641                         lp_best           = lp_proposal;
00642 
00643                         best_out << std::left << std::setw(7) << i
00644                                  << std::left << std::setw(14) << ll_best
00645                                  << std::left << std::setw(14) << lp_best;
00646                         for (itt = moves.begin(); itt != moves.end(); itt++)
00647                         {
00648                             best_out << (*itt)->get_info(params).str();
00649                         }
00650                         best_out << '\n';
00651                         best_out.flush();
00652                     }
00653                 }
00654                 else
00655                 {
00656                     move_out << std::left << std::setw(8) << "Reject";
00657                 }
00658 
00659                 double seconds = (tv2.tv_sec - tv1.tv_sec) + 
00660                          (tv2.tv_usec - tv1.tv_usec)*1.0e-6;
00661 
00662                 move_out << std::left << std::setw(14) << params->ll_proposal 
00663                          << std::left << std::setw(14) << lp_proposal
00664                          << std::left << std::setprecision(4) << std::setw(10)
00665                          << seconds
00666                          << move->get_proposal_info(params).str() << '\n';
00667                 move_out.flush();
00668 
00669                 try
00670                 {
00671                   params->alternaria->get_root()->recursively_check_structure();
00672                 }
00673                 catch (Exception e)
00674                 {
00675                     std::cerr << "alternaria: sampler.cpp: " 
00676                               << e.get_msg() << '\n';
00677                     abort();
00678                 }
00679 
00680                 break;
00681             }
00682         }
00683 
00684         if (alt_pro_fmt)
00685         {
00686             char buf[256] = {0};
00687             sprintf(buf, alt_pro_fmt, i);
00688             params->alternaria_proposal->print(buf);
00689         }
00690     }
00691 }