/*-
 * Copyright (c) 2009, Alexandre P. Francisco <aplf@ist.utl.pt>
 *
 * Permission to use, copy, modify, and/or distribute this software for any
 * purpose with or without fee is hereby granted, provided that the above
 * copyright notice and this permission notice appear in all copies.
 *
 * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
 * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
 * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
 * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
 * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
 * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
 * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
 */

package rank;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.Iterator;
import java.util.TreeMap;
import java.util.concurrent.TimeUnit;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import cern.colt.list.IntArrayList;
import cern.colt.map.OpenIntDoubleHashMap;

import com.martiansoftware.jsap.FlaggedOption;
import com.martiansoftware.jsap.JSAP;
import com.martiansoftware.jsap.JSAPException;
import com.martiansoftware.jsap.JSAPResult;
import com.martiansoftware.jsap.Parameter;
import com.martiansoftware.jsap.SimpleJSAP;
import com.martiansoftware.jsap.UnflaggedOption;

import it.unimi.dsi.logging.ProgressLogger;
import it.unimi.dsi.webgraph.labelling.BitStreamArcLabelledImmutableGraph;

/**
 * Expands a set of seed clusters over a weighted graph by ranking nodes
 * (heat kernel) and sweeping the rank vector, keeping for each cluster the
 * sweep prefix with the first best score. Also provides a command-line
 * entry point that loads the graph and the seed sets ("cores") from disk
 * and dumps the expanded clustering to standard output.
 */
public class LocalExpander {

	// Idiomatic SLF4J logger: final, and built from the Class object directly.
	private static final Logger LOGGER = LoggerFactory.getLogger(LocalExpander.class);
	/** Minimum time interval between progress-log messages, in seconds. */
	public static long LOGINT = 5;

	// Cluster id -> sparse membership vector (node id -> weight); mutated in place by expand().
	private TreeMap<Integer,OpenIntDoubleHashMap> clustering;
	// Sweep score of each expanded cluster, indexed by the iteration order of 'clustering'.
	private double[] clusteringScore;

	/**
	 * Builds the expander and immediately expands every seed set in place.
	 *
	 * @param graph the weighted graph to expand over.
	 * @param clustering map from cluster id to seed set (node id -&gt; weight); modified in place.
	 * @param alpha rank-method parameter (e.g. heat-kernel temperature).
	 * @param k maximum number of rank-method iterations.
	 * @param beta score correction factor passed to the sweep.
	 * @param scale sweep scale factor.
	 */
	public LocalExpander(WeightedImmutableGraph graph, TreeMap<Integer,OpenIntDoubleHashMap> clustering,
			double alpha, int k, double beta, double scale) {

		this.clustering = clustering;
		this.clusteringScore = new double[clustering.size()];

		expand(graph, alpha, k, beta, scale);
	}

	/**
	 * Replaces each seed set with its expanded cluster: the heat-kernel rank
	 * vector is swept and the prefix up to the first best sweep index is kept.
	 * Records each cluster's score in {@link #clusteringScore} and logs the average.
	 */
	private void expand(WeightedImmutableGraph graph, double alpha, int k, double beta, double scale) {

		double avg = 0.0;
		int clID = 0;

		ProgressLogger pl = new ProgressLogger(LOGGER, LOGINT, TimeUnit.SECONDS);
		pl.start("Expanding " + clustering.size() + " seed sets...");
		pl.expectedUpdates = clustering.size();
		pl.itemsName = "clusters";
		RankMethod lo = new HeatKernel(graph, alpha, k);
		RankMethod lp = new PageRank(graph, alpha, k);
		Iterator<OpenIntDoubleHashMap> iter = clustering.values().iterator();
		while (iter.hasNext()) {
			OpenIntDoubleHashMap v = iter.next();

			LOGGER.debug("Seed set ({}): {}", v.size(), v);

			OpenIntDoubleHashMap w = lo.compute(v);
			// compute() does not clear the method's internal state, so it must be
			// reset explicitly between seed sets.
			lo.reset();
			LOGGER.debug("Rank vector ({}): {}", w.size(), w);
			// NOTE(review): the PageRank vector 'wp' is computed and logged but never
			// used by the sweep below — presumably kept for comparison/debugging;
			// confirm before removing.
			OpenIntDoubleHashMap wp = lp.compute(v);
			lp.reset();
			LOGGER.debug("Rank vector ({}): {}", wp.size(), wp);

			Sweep swp = new Sweep(graph, w, beta, v, scale);

			// Rebuild v as the sweep prefix achieving the first best score.
			v.clear();
			int[] s = swp.getSortedIndex();

			for (int i = 0; i <= swp.getFirstBestIndex(); i++)
				v.put(s[i], w.get(s[i]));
			v.trimToSize();

			avg += swp.getFirstBestScore();

			clusteringScore[clID++] = swp.getFirstBestScore();

			pl.update();
		}
		pl.done();

		// Compute average.
		avg /= clustering.size();

		LOGGER.info("Average score: {}", avg);
	}

	/** Returns the cluster with the given id, or {@code null} if there is none. */
	public OpenIntDoubleHashMap getCluster(int u) {
		return clustering.get(u);
	}

	/** Returns the live clustering map; callers share the internal state. */
	public TreeMap<Integer,OpenIntDoubleHashMap> getClustering() {
		return clustering;
	}

	/**
	 * Writes the clustering to standard output, one cluster per line: the
	 * cluster score followed by its node ids in ascending order. An empty
	 * cluster produces an empty line.
	 */
	public void dumpClustering() {

		int clID = 0;
		Iterator<OpenIntDoubleHashMap> iter = clustering.values().iterator();
		while (iter.hasNext()) {
			IntArrayList idx = iter.next().keys();

			if (idx.size() == 0) {
				// BUG FIX: the score slot of an empty cluster must still be
				// consumed; otherwise every following cluster would be printed
				// with the preceding cluster's score.
				clID++;
				System.out.println();
				continue;
			}

			System.out.print(clusteringScore[clID++] + " ");

			idx.sortFromTo(0, idx.size() - 1);
			System.out.print(idx.getQuick(0));
			for (int i = 1; i < idx.size(); i++)
				System.out.print(" " + idx.getQuick(i));
			System.out.println();
		}
	}

	/**
	 * Command-line entry point: parses options, loads the labelled graph and
	 * the cores file (one seed set per line, space-separated node ids),
	 * expands every seed set and dumps the resulting clustering to stdout.
	 */
	public static void main(String[] args) {

		String basename = null;
		double alpha = 1.0;
		int k = 1;
		double beta = 1.0;
		double scale = Double.MAX_VALUE;
		String cores = null;

		try {
			SimpleJSAP jsap = new SimpleJSAP(
					LocalExpander.class.getName(),
					"Computes a soft clustering for the given (weighted) compressed graph.",
					new Parameter[] {
						new FlaggedOption("logInterval", JSAP.LONG_PARSER,
							"5", JSAP.NOT_REQUIRED, 'l', "log-interval",
							"The minimum time interval between activity logs in seconds."),
						new FlaggedOption("alpha", JSAP.DOUBLE_PARSER,
							"1.0", JSAP.REQUIRED, 'a', "alpha",
							"Parameter alpha for the rank method, e.g., heat kernel temperature." ),
						new FlaggedOption("k", JSAP.INTEGER_PARSER,
							"1", JSAP.REQUIRED, 'k', "max-iterations",
							"Maximum number of iterations for the rank method." ),
						new FlaggedOption("beta", JSAP.DOUBLE_PARSER,
							"1.0", JSAP.REQUIRED, 'b', "beta",
							"Score correction factor." ),
						new FlaggedOption("scale", JSAP.DOUBLE_PARSER,
							"1.0", JSAP.NOT_REQUIRED, 's', "scale",
							"Sweep scale factor." ),
				   		new UnflaggedOption("basename", JSAP.STRING_PARSER,
							JSAP.NO_DEFAULT, JSAP.REQUIRED, JSAP.NOT_GREEDY,
							"The basename of the graph."),
						new UnflaggedOption("cores", JSAP.STRING_PARSER,
							JSAP.NO_DEFAULT, JSAP.REQUIRED, JSAP.NOT_GREEDY,
							"The cores filename.")
					});

			JSAPResult jsapResult = jsap.parse(args);
			if (jsap.messagePrinted())
				System.exit(0);

			LOGGER.info("Parsing arguments...");
			basename = jsapResult.getString("basename");
			alpha = jsapResult.getDouble("alpha");
			k = jsapResult.getInt("k");
			beta = jsapResult.getDouble("beta");
			scale = jsapResult.getDouble("scale");
			// A negative scale is a sentinel meaning "unbounded sweep".
			if (scale < 0)
				scale = Double.MAX_VALUE;
			LocalExpander.LOGINT = jsapResult.getLong("logInterval");
			cores = jsapResult.getString("cores");

		} catch (JSAPException e) {
			// Pass the exception itself so the stack trace is not lost.
			LOGGER.error("Failed to parse command-line arguments.", e);
			System.exit(1);
		}

		LOGGER.info("Completed. ("
				+ LocalExpander.LOGINT + ", "
				+ alpha + ", "
				+ k + ", "
				+ beta + ", "
				+ scale + ", "
				+ basename + ", " 
				+ cores + ")");

		LOGGER.info("Loading graph...");
		WeightedImmutableGraph g = null;
		try {
			g = new WeightedImmutableGraph(BitStreamArcLabelledImmutableGraph.load(basename,
					new ProgressLogger(LOGGER, LOGINT, TimeUnit.SECONDS)));
		} catch (IOException e) {
			System.err.println("Error: Failed to load graph '"+ basename +"'.");
			e.printStackTrace();
			System.exit(1);
		}

		LOGGER.info("Loading seed sets...");

		TreeMap<Integer,OpenIntDoubleHashMap> clustering = new TreeMap<Integer,OpenIntDoubleHashMap>();

		BufferedReader br = null;
		try {

			br = new BufferedReader(new FileReader(new File(cores)));

			int p = 0;
			String line = null;
			while ((line = br.readLine()) != null) {
				// Skip blank lines (e.g. empty clusters in a previously dumped
				// clustering); parsing "" would throw NumberFormatException.
				if (line.trim().length() == 0)
					continue;

				String[] token = line.split(" ");
				OpenIntDoubleHashMap core = new OpenIntDoubleHashMap();

				// Seed weights are uniform over the set.
				for (int i = 0; i < token.length; i++)
					core.put(Integer.parseInt(token[i]), 1.0 / token.length);

				clustering.put(p++, core);
			}

		} catch (IOException e) {
			System.err.println("Error: Failed to load cores from '"+ cores +"'.");
			e.printStackTrace();
			System.exit(1);
		} finally {
			// Close the reader on every path, not just on success.
			if (br != null) {
				try {
					br.close();
				} catch (IOException ignored) {
					// Best effort: nothing sensible to do on close failure.
				}
			}
		}

		LOGGER.info("Init a new ClusterIt instance.");
		LocalExpander clusterIt = new LocalExpander(g, clustering, alpha, k, beta, scale);

		clusterIt.dumpClustering();
	}
}
