Go to the documentation of this file.00001 #include <RooGlobalFunc.h>
00002 #include <RooMsgService.h>
00003 #include <RooProdPdf.h>
00004 #include <RooWorkspace.h>
00005 #include <TFile.h>
00006
00007 #include "../../BAT/BCMath.h"
00008 #include <iostream>
00009
00010 #include "BCRooInterface.h"
00011
00012 #include <RooUniform.h>
00013
00014
00015 #include "RooRealVar.h"
00016 #include "RooAbsReal.h"
00017 #include "RooRandom.h"
00018
00019
00020
00021
00022
00023 void BCRooInterface::Initialize( RooAbsData& data,
00024 RooAbsPdf& model,
00025 RooAbsPdf& prior_trans,
00026 const RooArgSet* params,
00027 const RooArgSet& listPOI )
00028 {
00029
00030 fData = &data;
00031 fModel = &model;
00032
00033
00034 RooAbsPdf* prior_total = &prior_trans;
00035 if (prior_total!=0 ) {
00036 fPrior = prior_total;
00037 }
00038 else {
00039 std::cout << "No prior PDF: without taking action the program would crash\n";
00040 std::cout << "No prior PDF: adding dummy uniform pdf on the interval [0..1]\n";
00041 priorhelpvar = new RooRealVar("priorhelpvar","",0.0, 1.0 );
00042 _addeddummyprior = true;
00043 RooUniform* priorhelpdist = new RooUniform("priorhelpdist","", *priorhelpvar);
00044 fPrior = priorhelpdist;
00045 }
00046
00047 std::cout << "Imported parameters:\n";
00048 fParams = new RooArgList(listPOI);
00049 const RooArgSet* paramsTmp = params;
00050 if (paramsTmp!=0)
00051 fParams->add(*paramsTmp);
00052 fParams->Print("v");
00053
00054 fParamsPOI = new RooArgList(listPOI);
00055
00056
00057
00058 RooArgSet* constrainedParams = fModel->getParameters(*fData);
00059 fNll = fModel->createNLL(*fData, RooFit::Constrain(*constrainedParams) );
00060
00061 DefineParameters();
00062
00063 if(_fillChain) {
00064 SetupRooStatsMarkovChain();
00065 }
00066 }
00067
00068
00069
00070 void BCRooInterface::Initialize( const char* rootFile,
00071 const char* wsName,
00072 const char* dataName,
00073 const char* modelName,
00074 const char* priorName,
00075 const char* priorNuisanceName,
00076 const char* paramsName,
00077 const char* listPOIName )
00078 {
00079
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093 std::cout << "Opening " << rootFile << std::endl;
00094 TFile* file = new TFile(rootFile);
00095 std::cout << "content :\n";
00096 file->ls();
00097
00098 RooWorkspace* bat_ws = (RooWorkspace*) file->Get(wsName);
00099 bat_ws->Print("v");
00100
00101 fData = (RooAbsData*) bat_ws->data(dataName);
00102 fModel = (RooAbsPdf*) bat_ws->function(modelName);
00103
00104
00105 RooAbsPdf* priorPOI = (RooAbsPdf*) bat_ws->function(priorName);
00106 RooAbsPdf* priorNuisance = (RooAbsPdf*) bat_ws->pdf(priorNuisanceName);
00107 if (priorNuisance!=0 && priorPOI!=0) {
00108 fPrior = new RooProdPdf("fPrior","complete prior",*priorPOI,*priorNuisance);
00109 }
00110 else {
00111 if ( priorNuisance!=0 )
00112 fPrior=priorNuisance;
00113 else if ( priorPOI!=0 )
00114 fPrior = priorPOI;
00115 else{
00116 std::cout << "No prior PDF: without taking action the program would crash\n";
00117 std::cout << "No prior PDF: adding dummy uniform pdf on the interval [0..1]\n";
00118 priorhelpvar = new RooRealVar("priorhelpvar","",0.0, 1.0 );
00119 _addeddummyprior = true;
00120 RooUniform* priorhelpdist = new RooUniform("priorhelpdist","", *priorhelpvar);
00121 fPrior = priorhelpdist;
00122 }
00123 }
00124
00125 std::cout << "Imported parameters:\n";
00126 fParams = new RooArgList(*(bat_ws->set(listPOIName)));
00127 RooArgSet* paramsTmp = (RooArgSet*) bat_ws->set(paramsName);
00128 if (paramsTmp!=0) {
00129 fParams->add(*paramsTmp);
00130 }
00131 if (_addeddummyprior == true ) {
00132 fParams->add(*priorhelpvar);
00133 }
00134 fParams->Print("v");
00135
00136
00137
00138 RooArgSet* constrainedParams = fModel->getParameters(*fData);
00139 fNll = fModel->createNLL(*fData, RooFit::Constrain(*constrainedParams) );
00140
00141 file->Close();
00142
00143 DefineParameters();
00144 }
00145
00146
00147
00148 BCRooInterface::BCRooInterface() : BCModel()
00149 {
00150 _default_nbins = 500;
00151 _fillChain = false;
00152 _addeddummyprior = false;
00153 }
00154
00155
00156 BCRooInterface::BCRooInterface(const char* name, bool fillChain) : BCModel(name)
00157 {
00158 _default_nbins = 500;
00159 _fillChain = fillChain;
00160 }
00161
00162
00163 BCRooInterface::~BCRooInterface()
00164 {
00165 if(_fillChain) {
00166 delete _roostatsMarkovChain;
00167 }
00168
00169
00170
00171 }
00172
00173
00174 void BCRooInterface::DefineParameters()
00175 {
00176
00177 int nParams = fParams->getSize();
00178 for (int iParam=0; iParam<nParams; iParam++) {
00179 RooRealVar* ipar = (RooRealVar*) fParams->at(iParam);
00180 this->AddParameter(ipar->GetName(),ipar->getMin(),ipar->getMax());
00181 this->SetNbins(ipar->GetName(),_default_nbins);
00182 std::cout << "added parameter: " << ipar->GetName() << " defined in range [ " << ipar->getMin() << " - " << ipar->getMax() << " ] with number of bins: " << _default_nbins << " \n";
00183 }
00184
00185 for(std::list< std::pair<const char*,int> >::iterator listiter = _nbins_list.begin(); listiter != _nbins_list.end(); listiter++) {
00186 this->SetNbins((*listiter).first,(*listiter).second);
00187 std::cout << "adjusted parameter: " << (*listiter).first << " to number of bins: " << (*listiter).second << " \n";
00188 }
00189
00190 }
00191
00192
00193 double BCRooInterface::LogLikelihood(const std::vector<double> & parameters)
00194 {
00195
00196
00197 int nParams = fParams->getSize();
00198 for (int iParam=0; iParam<nParams; iParam++) {
00199 RooRealVar* ipar = (RooRealVar*) fParams->at(iParam);
00200 ipar->setVal(parameters.at(iParam));
00201 }
00202
00203
00204 double logprob = -fNll->getVal();
00205 return logprob;
00206 }
00207
00208
00209 double BCRooInterface::LogAPrioriProbability(const std::vector<double> & parameters)
00210 {
00211
00212 int nParams = fParams->getSize();
00213 for (int iParam=0; iParam<nParams; iParam++) {
00214 RooRealVar* ipar = (RooRealVar*) fParams->at(iParam);
00215 ipar->setVal(parameters.at(iParam));
00216 }
00217
00218 RooArgSet* tmpArgSet = new RooArgSet(*fParams);
00219 double prob = fPrior->getVal(tmpArgSet);
00220 delete tmpArgSet;
00221 if (prob<1e-300)
00222 prob = 1e-300;
00223 return log(prob);
00224 }
00225
00226 void BCRooInterface::SetNumBins(const char * parname, int nbins)
00227 {
00228 for(std::list< std::pair<const char*,int> >::iterator listiter = _nbins_list.begin(); listiter != _nbins_list.end(); listiter++) {
00229 if(!strcmp((*listiter).first, parname)) {
00230 (*listiter).second = nbins;
00231 return;
00232 }
00233 }
00234 _nbins_list.push_back( std::make_pair(parname,nbins) );
00235 }
00236
00237 void BCRooInterface::SetNumBins(int nbins)
00238 {
00239 _default_nbins = nbins;
00240 }
00241
00242 void BCRooInterface::SetupRooStatsMarkovChain()
00243 {
00244
00245
00246
00247
00248
00249
00250
00251
00252 _parametersForMarkovChainPrevious.add(*fParams);
00253 _parametersForMarkovChainCurrent.add(*fParams);
00254
00255 std::cout << "size of _parametersForMarkovChain: " << _parametersForMarkovChainCurrent.getSize() << std::endl;
00256 std::cout << "size of fParamsPOI: " << fParamsPOI->getSize() << std::endl;
00257
00258
00259 _roostatsMarkovChain = new RooStats::MarkovChain();
00260
00261
00262
00263
00264 std::cout << "setting up parameters for RooStats markov chain" << std::endl;
00265 _parametersForMarkovChainPrevious.writeToStream(std::cout, false);
00266
00267
00268 int nchains = MCMCGetNChains();
00269 for(int countChains = 1; countChains<=nchains ; countChains++ ) {
00270 double tempweight = 1.0;
00271 fVecWeights.push_back(tempweight);
00272 std::vector<double> tempvec;
00273 TIterator* setiter = fParamsPOI->createIterator();
00274 double tempval = 0.;
00275 while(setiter->Next()){
00276 tempvec.push_back(tempval);
00277 }
00278 fPreviousStep.push_back(tempvec);
00279 fCurrentStep.push_back(tempvec);
00280 }
00281
00282 fFirstComparison = true;
00283
00284
00285
00286
00287
00288
00289
00290
00291 }
00292
00293
00294 void BCRooInterface::MCMCIterationInterface()
00295 {
00296
00297
00298 if(_fillChain) {
00299
00300
00301 int nchains = MCMCGetNChains();
00302
00303
00304 int npar = GetNParameters();
00305
00306
00307
00308 for (int i = 0; i < nchains; ++i) {
00309
00310
00311
00312
00313 TIterator* setiter = fParams->createIterator();
00314 int j = 0;
00315
00316
00317
00318
00319
00320 while(setiter->Next()){
00321
00322
00323 BCParameter * tempBCparam = GetParameter(j);
00324
00325
00326
00327 const char * paramnamepointer = (tempBCparam->GetName()).c_str();
00328 double xij = fMCMCx.at(i * npar + j);
00329 AddToCurrentChainElement(xij, i, j);
00330 RooRealVar* parampointer = (RooRealVar*) &(_parametersForMarkovChainCurrent[paramnamepointer]);
00331 parampointer->setVal(xij);
00332
00333 j++;
00334 }
00335
00336
00337
00338
00339
00340
00341
00342
00343
00344
00345
00346 if( !(EqualsLastChainElement(i)) ) {
00347 double weight = GetWeightForChain(i);
00348 _roostatsMarkovChain->Add(_parametersForMarkovChainPrevious, -1.* MCMCGetLogProbx(j), weight);
00349 _parametersForMarkovChainPrevious = _parametersForMarkovChainCurrent;
00350 }
00351 }
00352 }
00353 }
00354
00355 void BCRooInterface::AddToCurrentChainElement(double xij, int chainNum, int poiNum)
00356 {
00357 fCurrentStep[chainNum][poiNum] = xij;
00358 }
00359
00360 bool BCRooInterface::EqualsLastChainElement(int chainNum)
00361 {
00362 bool equals = true;
00363 std::vector<double>::iterator itPrevious = fPreviousStep[chainNum].begin();
00364
00365 if(fFirstComparison == true) {
00366 fFirstComparison = false;
00367 _parametersForMarkovChainPrevious = _parametersForMarkovChainCurrent;
00368 return true;
00369 }
00370
00371
00372 for (std::vector<double>::iterator itCurrent = fCurrentStep[chainNum].begin(); itCurrent != fCurrentStep[chainNum].end(); ++itCurrent) {
00373 if(*itCurrent != *itPrevious) {
00374 equals = false;
00375 fPreviousStep[chainNum] = fCurrentStep[chainNum];
00376 break;
00377 }
00378 ++itPrevious;
00379 }
00380
00381 if(equals == true) {
00382 fVecWeights[chainNum] += 1.0;
00383 }
00384
00385 return equals;
00386
00387 }
00388
00389 double BCRooInterface::GetWeightForChain(int chainNum)
00390 {
00391 double retval = fVecWeights[chainNum];
00392 fVecWeights[chainNum]= 1.0 ;
00393 return retval;
00394 }
00395