package blog;

import java.util.*;
import java.io.*;
import common.Histogram;
import common.Util;
import common.UnaryPredicate;

import junit.framework.TestCase;

/**
 * Unit tests for the {@link ParticleFilter}.
 *
 * <p>Each test builds a small BLOG model, feeds evidence and queries to a
 * particle filter, and checks the estimated probability of a value against
 * a reference value, allowing an absolute error of {@link #DELTA}.
 * Because the estimates come from random sampling, an estimate can land
 * outside any finite error margin, so a test may occasionally fail even
 * when the filter is correct. This should be rare; if it happens, inspect
 * the printed probabilities to judge whether the error looks reasonable,
 * or simply run the test again.
 */
public class ParticleFilterTest extends TestCase {

    /** Maximum allowed absolute difference between expected and estimated probabilities. */
    private static final double DELTA = 0.05;

    /**
     * Particle filter configuration for the current test. It is reset to the
     * defaults by {@link #setUp()} and handed to the filter by {@link #setModel(String)}.
     * Instance (not static) state: JUnit 3 creates a fresh instance per test method.
     */
    private Properties properties;

    /** Model for the current test, created by {@link #setModel(String)}. */
    private Model model;

    /** Particle filter over {@link #model}, created by {@link #setModel(String)}. */
    private ParticleFilter engine;

    /**
     * Runs this test class with the JUnit text runner after calling
     * {@code Util.initRandom(true)}. That call presumably initializes the
     * shared random number generator; confirm its exact semantics in {@code Util}.
     */
    public static void main(String[] args) throws Exception {
        Util.initRandom(true);
        junit.textui.TestRunner.run(ParticleFilterTest.class);
    }

    /** Resets the particle filter properties to their default values before every test. */
    public void setUp() {
        setDefaultParticleFilterProperties();
    }

    /** Replaces {@link #properties} with a fresh set holding the default filter settings. */
    private void setDefaultParticleFilterProperties() {
        properties = new Properties();
        properties.setProperty("numParticles", "5000");
        properties.setProperty("useDecayedMCMC", "true");
        properties.setProperty("numMoves", "1");
    }

    /**
     * BLOG model of daily weather.
     *
     * <ul>
     * <li>{@code RainyRegion} takes Rainy or Dry with probability 0.5 each.</li>
     * <li>At {@code @0}, {@code Weather} depends only on {@code RainyRegion}.</li>
     * <li>At later timesteps, {@code Weather} depends on {@code RainyRegion} and
     * on the previous timestep's weather.</li>
     * </ul>
     */
    private static final String WEATHER_MODEL_STRING =
        "type RainEvent;" +
        "guaranteed RainEvent Rainy, Dry;" +

        "random RainEvent Weather(Timestep);" +
        "random RainEvent RainyRegion();" +

        "RainyRegion ~ TabularCPD[[0.5, 0.5]]();" +

        "Weather(d) " +
        " 	if (d = @0) then ~ TabularCPD[[0.7, 0.3],[0.3, 0.7]](RainyRegion)" +
        "	else ~ TabularCPD[[0.8 , 0.2]," +
        "	                  [0.3, 0.7]," +
        "	                  [0.6 , 0.4]," +
        "	                  [0.2, 0.8]]" +
        "	             (RainyRegion, Weather(Prev(d)));";

    /**
     * Observes Rainy at timesteps 0, 1 and 2, checking that each observed
     * variable is reported as Rainy with probability 1. It then checks the
     * prediction for timestep 3 and the estimate for {@code RainyRegion}.
     */
    public void test1() throws Exception {
        setModel(WEATHER_MODEL_STRING);

        assertProb("obs Weather(@0)=Rainy; query Weather(@0);", "Rainy", 1);
        assertProb("obs Weather(@1)=Rainy; query Weather(@1);", "Rainy", 1);
        assertProb("obs Weather(@2)=Rainy; query Weather(@2);", "Rainy", 1);
        assertProb("query Weather(@3);", "Rainy", 0.7611510791366907); // calculated by exact inference
        assertProb("query RainyRegion;", "Rainy", 0.8057553956834532);
    }

    /**
     * Queries the model without any evidence. RainyRegion is checked
     * against its 0.5 prior from the model's TabularCPD.
     */
    public void test2() throws Exception {
        setModel(WEATHER_MODEL_STRING);

        assertProb("query Weather(@3);", "Rainy", 0.47185);
        assertProb("query RainyRegion;", "Rainy", 0.5);
    }

    /** Queries a later timestep (@15) using a smaller particle population. */
    public void testLongerInterval() throws Exception {
        // Configure before setModel: setModel passes the current properties to
        // the ParticleFilter constructor, so changing them afterwards is too late.
        // No restore is needed at the end because setUp resets the defaults before every test.
        properties.setProperty("numParticles", "1000");
        setModel(WEATHER_MODEL_STRING);

        assertProb("query Weather(@15);", "Rainy", 0.45);
        assertProb("query RainyRegion;", "Rainy", 0.5);
    }

    /**
     * Creates a new {@link Model} from the given BLOG source via
     * {@code Main.stringSetup}, then a new {@link ParticleFilter} over it
     * configured with the current {@link #properties}.
     * The Evidence and query list passed to {@code stringSetup} are discarded.
     * NOTE(review): this presumably means the model string carries only declarations; confirm.
     *
     * @param newModelString BLOG model text
     */
    private void setModel(String newModelString) throws Exception {
        model = new Model();
        Main.stringSetup(model, new Evidence(), new LinkedList(), newModelString);
        engine = new ParticleFilter(model, properties);
    }

    /**
     * Parses a snippet of evidence and a single query, passes them to the
     * filter, and asserts the estimated probability of a value.
     * The estimated distributions are printed before the assertion so that
     * they are available even when the assertion fails.
     *
     * @param evidenceAndQuery BLOG text with optional {@code obs} statements and exactly one query
     * @param valueString      name of the constant whose probability is checked
     * @param expected         reference probability, compared within {@link #DELTA}
     */
    private void assertProb(String evidenceAndQuery, String valueString, double expected) throws Exception {
        BLOGParser.ModelEvidenceQueries meq = BLOGParser.parseString(model, evidenceAndQuery);
        engine.take(meq.evidence);
        engine.answer(meq.queries);
        outputQueries(meq.queries);
        double actual = getProbabilityByString(getQuery(meq.queries), valueString);
        assertEquals("P(" + valueString + ") for \"" + evidenceAndQuery + "\"", expected, actual, DELTA);
    }

    /**
     * Prints each value's normalized weight for every answered query.
     *
     * @param queries collection of answered {@link ArgSpecQuery} objects
     */
    private void outputQueries(Collection queries) {
        for (Iterator queryIt = queries.iterator(); queryIt.hasNext();) {
            ArgSpecQuery query = (ArgSpecQuery) queryIt.next();
            Histogram hist = query.getHistogram();
            double totalWeight = hist.getTotalWeight();
            for (Iterator entryIt = hist.entrySet().iterator(); entryIt.hasNext();) {
                Histogram.Entry entry = (Histogram.Entry) entryIt.next();
                double prob = entry.getWeight() / totalWeight;
                System.out.println("Prob. of " + query + " = " + entry.getElement() + " is " + prob);
            }
        }
    }

    /**
     * Returns the normalized weight, in the query's histogram, of the model
     * constant named {@code valueString}. The constant is looked up in {@link #model}.
     */
    private double getProbabilityByString(ArgSpecQuery query, String valueString) {
        Histogram hist = query.getHistogram();
        Object value = model.getConstantValue(valueString);
        return hist.getWeight(value) / hist.getTotalWeight();
    }

    /** Returns the first element of a collection assumed to hold exactly one query. */
    private static ArgSpecQuery getQuery(Collection singleton) {
        return (ArgSpecQuery) Util.getFirst(singleton);
    }
}
