// ArgSpecQuery.java

/*
 * Copyright (c) 2005, 2006, Regents of the University of California
 * All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 * * Redistributions of source code must retain the above copyright
 *   notice, this list of conditions and the following disclaimer.
 *
 * * Redistributions in binary form must reproduce the above copyright
 *   notice, this list of conditions and the following disclaimer in
 *   the documentation and/or other materials provided with the
 *   distribution.  
 *
 * * Neither the name of the University of California, Berkeley nor
 *   the names of its contributors may be used to endorse or promote
 *   products derived from this software without specific prior 
 *   written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
 * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
 * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
 * FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
 * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
 * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
 * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
 * OF THE POSSIBILITY OF SUCH DAMAGE.
 */

package blog;

import java.util.*;
import java.io.PrintStream;

import common.Histogram;
import ve.Factor;
import ve.Potential;

public class ArgSpecQuery extends AbstractQuery {

    /**
     * Creates a query on the given argument specification.
     *
     * <p>If <code>Main.histOut()</code> is non-null, this also opens the
     * histogram output stream, whose file name is built from
     * <code>Main.histOut()</code>, the literal <code>-trial</code>, the
     * current value of <code>trialNum</code> and the suffix
     * <code>.data</code>.
     *
     * @param argSpec the argument specification this query asks about
     */
    public ArgSpecQuery(ArgSpec argSpec) {
        this.argSpec = argSpec;

        if (Main.histOut() != null) {
            // A duplicated '+' used to follow "-trial", which the compiler
            // read as a unary plus on trialNum.  The resulting file name is
            // unchanged; the expression now matches the one in zeroOut().
            outputFile = Main.filePrintStream(Main.histOut() + "-trial"
                                              + trialNum + ".data");
        }
    }

    /**
     * Returns the argument specification currently held by this query.
     *
     * @return the value of the <code>argSpec</code> field
     */
    public ArgSpec argSpec() {
        return this.argSpec;
    }

    /**
     * Prints the distribution of values recorded for this query.
     *
     * <p>The output is a header line naming the ArgSpec, followed by one
     * line per histogram entry containing a tab, the entry's probability
     * (its weight divided by the histogram's total weight), another tab and
     * the entry's element.
     *
     * <p>If the ArgSpec is numeric, entries are listed in increasing order
     * of element value; otherwise they are listed from the largest weight
     * to the smallest.
     *
     * @param s the stream to print to
     */
    public void printResults(PrintStream s) {
        s.println("Distribution of values for " + argSpec);
        List<Object> entries = new ArrayList<Object>(histogram.entrySet());

        // Double.compare is used instead of hand-written </> tests so that
        // the comparators remain a total order even if a NaN shows up;
        // an inconsistent comparator can make Collections.sort throw.
        Comparator<Object> order;
        if (argSpec.isNumeric()) {
            // Sort in numerical order by element
            order = new Comparator<Object>() {
                    public int compare(Object o1, Object o2) {
                        Object e1 = ((Histogram.Entry) o1).getElement();
                        Object e2 = ((Histogram.Entry) o2).getElement();
                        return Double.compare(((Number) e1).doubleValue(),
                                              ((Number) e2).doubleValue());
                    }
                };
        } else {
            // Sort from most to least frequent (reverse of usual order),
            // hence the swapped arguments.
            order = new Comparator<Object>() {
                    public int compare(Object o1, Object o2) {
                        return Double.compare
                            (((Histogram.Entry) o2).getWeight(),
                             ((Histogram.Entry) o1).getWeight());
                    }
                };
        }
        Collections.sort(entries, order);

        // The total weight does not change while printing, so read it once.
        double totalWeight = histogram.getTotalWeight();
        for (Object o : entries) {
            Histogram.Entry entry = (Histogram.Entry) o;
            double prob = entry.getWeight() / totalWeight;
            s.println("\t" + prob + "\t" + entry.getElement());
        }
    }

    /**
     * Logs the current histogram to the per-element output files and, at
     * the end of a run, to the histogram output file.
     *
     * <p>For every histogram entry, one line containing a tab,
     * <code>numSamples</code>, a tab and the entry's current probability is
     * appended to the stream returned by <code>getOutputFile</code> for
     * that entry's element.
     *
     * <p>When <code>numSamples</code> equals <code>Main.numSamples()</code>
     * and <code>Main.histOut()</code> is non-null, the entries are sorted
     * by the integer value of each element's string form and written to
     * <code>outputFile</code>, one line per entry (tab, element, tab,
     * probability).
     *
     * @param numSamples the number of samples taken so far
     * @throws NumberFormatException in the final dump, if some element's
     *         <code>toString()</code> is not a valid integer
     */
    public void logResults(int numSamples) {
        final List<Object> entries =
            new ArrayList<Object>(histogram.entrySet());
        // Nothing below modifies the histogram, so read the total once.
        final double totalWeight = histogram.getTotalWeight();

        for (Object o : entries) {
            Histogram.Entry entry = (Histogram.Entry) o;
            double prob = entry.getWeight() / totalWeight;
            PrintStream s = getOutputFile(entry.getElement());
            s.println("\t" + numSamples + "\t" + prob);
        }

        if ((numSamples == Main.numSamples()) && (Main.histOut() != null)) {
            // Integer.valueOf replaces the deprecated Integer(String)
            // constructor; it parses the same way and throws the same
            // NumberFormatException on malformed input.
            Comparator<Object> byIntegerElement = new Comparator<Object>() {
                    public int compare(Object o1, Object o2) {
                        Integer i1 = Integer.valueOf
                            (((Histogram.Entry) o1).getElement().toString());
                        Integer i2 = Integer.valueOf
                            (((Histogram.Entry) o2).getElement().toString());
                        return i1.compareTo(i2);
                    }
                };
            Collections.sort(entries, byIntegerElement);
            for (Object o : entries) {
                Histogram.Entry entry = (Histogram.Entry) o;
                double prob = entry.getWeight() / totalWeight;
                outputFile.println("\t" + entry.getElement() + "\t" + prob);
            }
        }
    }

    /**
     * Returns a singleton collection holding this query's variable.
     *
     * @return a one-element collection containing <code>variable</code>
     * @throws IllegalStateException if <code>variable</code> is still null,
     *         i.e. the query has not been compiled yet
     */
    public Collection<? extends BayesNetVar> getVariables() {
        if (variable != null) {
            return Collections.singleton(variable);
        }
        throw new IllegalStateException("Query has not yet been compiled.");
    }

    /**
     * Checks the types and scope of this query's ArgSpec against the given
     * model, using an empty scope map.
     *
     * <p>If the ArgSpec is a <code>Term</code>, it is replaced by the result
     * of <code>getTermInScope</code> whenever that result is non-null.
     *
     * @param model the model to check against
     * @return for a <code>Term</code>, whether <code>getTermInScope</code>
     *         returned a non-null term; otherwise the result of the
     *         ArgSpec's own <code>checkTypesAndScope</code>
     */
    public boolean checkTypesAndScope(Model model) {
        if (!(argSpec instanceof Term)) {
            return argSpec.checkTypesAndScope(model, Collections.EMPTY_MAP);
        }

        Term resolved = ((Term) argSpec)
            .getTermInScope(model, Collections.EMPTY_MAP);
        if (resolved != null) {
            argSpec = resolved;
        }
        return resolved != null;
    }

    

    /**
     * Compiles the underlying ArgSpec, and initializes the variable
     * corresponding to this query.
     */


    public int compile() {
        // Compile the ArgSpec with a fresh, empty LinkedHashSet; the return
        // value is the number of errors it reports.
        final int errorCount = argSpec.compile(new LinkedHashSet());
        if (errorCount != 0) {
            // Leave the variable field untouched on failure.
            return errorCount;
        }
        variable = argSpec.getVariable();
        return errorCount;
    }

    /**
     * Evaluates the ArgSpec in the given world and adds the given weight to
     * the histogram entry for the resulting value.
     *
     * @param world  the partial world in which to evaluate the ArgSpec
     * @param weight the weight to add for the observed value
     */
    public void updateStats(PartialWorld world, double weight) {
        final Object observed = argSpec.evaluate(world);
        histogram.increaseWeight(observed, weight);
    }

    /**
     * Replaces the histogram contents with the values of the given factor.
     *
     * <p>The histogram is cleared and then, for each guaranteed object of
     * the type of the potential's first dimension, the potential's value at
     * that object is added as the object's weight.
     *
     * @param posterior a factor over exactly the query variable
     * @throws IllegalArgumentException if the factor's random variables do
     *         not include the query variable, or include more than one
     *         variable
     */
    public void setPosterior(Factor posterior) {
        final boolean coversQueryVar =
            posterior.getRandomVars().contains((BasicVar) variable);
        if (!coversQueryVar) {
            throw new IllegalArgumentException
                ("Query variable " + variable + " not covered by factor on "
                 + posterior.getRandomVars());
        }

        final int numFactorVars = posterior.getRandomVars().size();
        if (numFactorVars > 1) {
            throw new IllegalArgumentException
                ("Answer to query on " + variable
                 + " should be factor on that variable alone, not "
                 + posterior.getRandomVars());
        }

        final Potential potential = posterior.getPotential();
        final Type dimType = potential.getDims().get(0);

        histogram.clear();
        for (Object obj : dimType.getGuaranteedObjects()) {
            histogram.increaseWeight
                (obj, potential.getValue(Collections.singletonList(obj)));
        }
    }

    /**
     * Resets this query for the next trial.
     *
     * <p>Increments <code>trialNum</code>; if a histogram output file is in
     * use and the new trial number differs from
     * <code>Main.numTrials()</code>, closes it and opens the file for the
     * new trial; closes and discards all per-element output streams; and
     * clears the histogram.  Across-run statistics are not recorded.
     */
    public void zeroOut( ){
        trialNum++;
        if ((outputFile != null) && (trialNum != Main.numTrials())) {
            // Close the finished trial's file before replacing the only
            // reference to it; previously the stream was simply dropped.
            // NOTE(review): assumes Main.filePrintStream returns a stream
            // used by this query alone -- confirm against Main.
            outputFile.close();
            outputFile = Main.filePrintStream(Main.histOut() +
                                              "-trial" + trialNum + ".data");
        }

        // The per-element streams created by getOutputFile become
        // unreachable from this object once the map is replaced, so close
        // them first instead of leaking their file handles.
        if (outputFiles != null) {
            for (Object stream : outputFiles.values()) {
                if (stream != null) {
                    ((PrintStream) stream).close();
                }
            }
        }
        outputFiles = new HashMap();

        histogram.clear();

        // We don't record across-run statistics
    }
 
    /**
     * Prints a one-line notice that the variance of this query's results is
     * not computed.
     *
     * @param s the stream to print to
     */
    public void printVarianceResults(PrintStream s){
        final String notice =
            "\tVariance of " + argSpec + " results is not computed.";
        s.println(notice);
    }

    

    /*
     * Print the query summary: mean histogram (because we have to calculate
     * it anyway), variation distance (sum of absolute values of each
     * element from its mean), and element-wise variance.
     *
     * Yes, I know this makes many, many passes over histList, but variation
     * distance needs to know the mean first, so I might as well change as
     * little code as possible.
     *
     * private void printVarStats(PrintStream s) {
     *     Set objs = new TreeSet();
     *     for (int i = 0; i < histList.size() ; i++) {
     *         objs.addAll(((Histogram) histList.get(i)).elementSet());
     *     }
     *
     *     double vardist = 0;
     *     Histogram meanHist = findMeanHist(objs);
     *     Histogram sumSquareHist = new Histogram();
     *     for (int i = 0 ; i < histList.size() ; i++) {
     *         Histogram hist = (Histogram) histList.get(i);
     *         Iterator objIter = hist.elementSet().iterator();
     *         while (objIter.hasNext()) {
     *             Object obj = objIter.next();
     *             double normWeight = hist.getWeight(obj) / hist.getTotalWeight();
     *             vardist += Math.abs(meanHist.getWeight(obj) - normWeight);
     *             sumSquareHist.increaseWeight(obj, normWeight * normWeight);
     *         }
     *     }
     *
     *     s.println("Query summary for: " + argSpec);
     *     s.println(" Mean histogram: (value, probability)");
     *     Iterator objIter = objs.iterator();
     *     while (objIter.hasNext()) {
     *         Object obj = objIter.next();
     *         s.println("\t" + obj + "\t" + meanHist.getWeight(obj));
     *     }
     *     s.println(" Variation distance: " + vardist);
     *     s.println(" Element-wise variance: (value, variance, standard dev)");
     *     objIter = objs.iterator();
     *     while (objIter.hasNext()) {
     *         Object obj = objIter.next();
     *         double mu2 = Math.pow(meanHist.getWeight(obj), 2);
     *         double var = Math.abs(sumSquareHist.getWeight(obj) - mu2);
     *         s.println("\t" + obj + "\t" + var + "\t" + Math.sqrt(var));
     *     }
     * }
     *
     * findMeanHist returns a Histogram representing the mean values for each
     * of the elements in histList.  It takes as argument a Set of Objects,
     * objs, which is a set of all objects in all Histograms in histList.
     *
     * private Histogram findMeanHist(Set objs) {
     *     Histogram meanHist = new Histogram();
     *     for (int i = 0 ; i < histList.size() ; i++) {
     *         Histogram hist = (Histogram) histList.get(i);
     *         double denom = hist.getTotalWeight();
     *         Iterator objIter = hist.elementSet().iterator();
     *         while (objIter.hasNext()) {
     *             Object obj = objIter.next();
     *             double numer = hist.getWeight(obj);
     *             meanHist.increaseWeight(obj, numer / denom);
     *         }
     *     }
     *
     *     // Normalize mean probabilities
     *     Histogram normMean = new Histogram();
     *     Iterator objIter = objs.iterator();
     *     double denom = meanHist.getTotalWeight();
     *     while (objIter.hasNext()) {
     *         Object obj = objIter.next();
     *         normMean.increaseWeight(obj, meanHist.getWeight(obj) / denom);
     *     }
     *     return normMean;
     * }
     */



    

    /**
     * Every object should have an output file.  If it does not yet exist,
     * create one; otherwise return it.
     */


    private PrintStream getOutputFile(Object o) {
        PrintStream existing = (PrintStream) outputFiles.get(o);
        if (existing != null) {
            return existing;
        }

        // First request for this object: open its file, whose name combines
        // the output path, the current trial number and the object's string
        // form, and remember the stream for later calls.
        final String fileName = Main.outputPath() + "-trial" + trialNum
            + "." + o.toString() + ".data";
        PrintStream created = Main.filePrintStream(fileName);
        outputFiles.put(o, created);
        return created;
    }

    /**
     * Returns the histogram in which this query accumulates its results.
     * The histogram object itself is returned, not a copy.
     *
     * @return the value of the <code>histogram</code> field
     */
    public Histogram getHistogram() {
        return this.histogram;
    }
    
    /**
     * Returns the location reported by this query's ArgSpec.
     *
     * @return the result of <code>argSpec.getLocation()</code>
     */
    public Object getLocation() {
        final Object location = argSpec.getLocation();
        return location;
    }

    /**
     * Returns the string form of the query variable if it has been set, and
     * of the ArgSpec otherwise.
     *
     * @return <code>variable.toString()</code> when <code>variable</code> is
     *         non-null, else <code>argSpec.toString()</code>
     */
    public String toString() {
        return (variable != null) ? variable.toString()
                                  : argSpec.toString();
    }

    // The argument specification being queried.  Set by the constructor;
    // checkTypesAndScope may replace it with the term returned by
    // getTermInScope.
    protected ArgSpec argSpec;
    // The variable obtained from argSpec.getVariable().  Stays null until
    // compile() finishes with zero errors.
    protected BayesNetVar variable;
    // Weights per value: increased by updateStats, rebuilt by setPosterior,
    // and cleared by zeroOut.
    protected Histogram histogram = new Histogram();
    //    private LinkedList histList = new LinkedList();  // of Histogram
    // Trial counter; incremented by zeroOut and embedded in output file
    // names.
    protected int trialNum = 0;

    // Per-element output streams, filled lazily by getOutputFile and
    // replaced with a fresh map by zeroOut.
    protected Map outputFiles = new HashMap();  // of PrintStream
    // Histogram output stream; opened in the constructor only when
    // Main.histOut() is non-null, otherwise left null.
    protected PrintStream outputFile = null;
}


// This file was generated on Tue Jun 08 17:53:36 PDT 2010 from file ArgSpecQuery.java
// by the ilog.language.tools.Hilite Java tool written by Hassan Aït-Kaci