How to implement Karger's min cut algorithm in Java? - java

I am trying to implement Karger's min cut algorithm but I am unable to get the correct answer. Can someone please have a look at my code and help me figure out what I am doing wrong? Would really appreciate the help.
package a3;
import java.util.*;
import java.io.*;
public class Graph {
private ArrayList<Integer>[] adjList;
private int numOfVertices = 0;
private int numOfEdges = 0;
public Graph(String file) throws IOException {
FileInputStream wordsFile = new FileInputStream(file);
BufferedReader br = new BufferedReader(new InputStreamReader(wordsFile));
adjList = (ArrayList<Integer>[]) new ArrayList[201];
for (int i = 1; i < 201; i++) {
adjList[i] = new ArrayList<Integer>();
}
while (true) {
String s = br.readLine();
if (s == null)
break;
String[] tokens = s.split("\t");
int vertex = Integer.parseInt(tokens[0]);
this.numOfVertices++;
for (int i = 1; i < tokens.length; i++) {
Integer edge = Integer.parseInt(tokens[i]);
addEdge(vertex, edge);
this.numOfEdges++;
}
}
}
public void addEdge(int v, int w) {
adjList[v].add(w);
//this.numOfEdges++;
}
public ArrayList<Integer> getNeighbors(Integer v) {
return adjList[v];
}
public boolean hasEdge(Integer i, Integer j) {
return adjList[i].contains(j);
}
public boolean removeEdge(Integer i, Integer j) {
if (hasEdge(i, j)) {
adjList[i].remove(j);
adjList[j].remove(i);
this.numOfEdges -= 2;
return true;
}
return false;
}
public int getRandomVertex(){
Random rand = new Random();
return (rand.nextInt(this.getNumOfVertices()) + 1);
}
//Returns an array which consists of vertices connected by chosen edge
public int[] getRandomEdge(){
int arr[] = new int[2];
arr[0] = this.getRandomVertex();
while (adjList[arr[0]].size() == 0) {
arr[0] = this.getRandomVertex();
}
Random rand = new Random();
arr[1] = adjList[arr[0]].get(rand.nextInt(adjList[arr[0]].size()));
return arr;
}
//Algorithm for min cut
public int minCut() {
while (this.getNumOfVertices() > 2) {
int[] edge = this.getRandomEdge();
this.removeEdge(edge[0], edge[1]);
//Adding edges of second vertex to first vertex
for (Integer v: adjList[edge[1]]) {
if (!adjList[edge[0]].contains(v)) {
addEdge(edge[0], v);
}
}
//Removing edges of second vertex
for (Iterator<Integer> it = adjList[edge[1]].iterator(); it.hasNext();) {
Integer v = it.next();
it.remove();
this.numOfEdges--;
}
//Removing self-loops
for (Iterator<Integer> it = adjList[edge[0]].iterator(); it.hasNext();) {
Integer v = it.next();
if (v == edge[0])
it.remove();
//this.numOfEdges--;
}
this.numOfVertices--;
}
return this.numOfEdges;
}
public int getNumOfVertices() {
return this.numOfVertices;
}
public int getNumOfEdges() {
return (this.numOfEdges) / 2;
}
public String toString() {
String s = "";
for (int v = 1; v < 201; v++) {
s += v + ": ";
for (int e : adjList[v]) {
s += e + "-> ";
}
s += null + "\n";
//s += "\n";
}
return s;
}
/**
* #param args
* #throws IOException
*/
public static void main(String[] args) throws IOException {
// TODO Auto-generated method stub
int min = 1000;
//Graph test = new Graph("C:\\Users\\UE\\Desktop\\kargerMinCut.txt");
for (int i = 0; i < 10; i++) {
Graph test = new Graph("C:\\Users\\UE\\Desktop\\kargerMinCut.txt");
int currMin = test.minCut();
min = Math.min( min, currMin );
}
System.out.println(min);
}
}

Related

Binary heap output not as expected

I have a homework that the teacher test if it's corrects by checking it's output using this website moodle.caseine.org, so to test my code the program execute these lines and compare the output with the expected one, this is the test :
Tas t = new Tas();
Random r = new Random(123);
for(int i =0; i<10000;i++)t.inser(r.nextInt());
for(int i =0;i<10000;i++)System.out.println(t.supprMax());
System.out.println(t);
And my Heap (Tas) class:
package td1;
import java.util.ArrayList;
import java.util.List;
public class Tas {
private List<Integer> t;
public Tas() {
t = new ArrayList<>();
}
public Tas(ArrayList<Integer> tab) {
t = new ArrayList<Integer>(tab);
}
public static int getFilsGauche(int i) {
return 2 * i + 1;
}
public static int getFilsDroit(int i) {
return 2 * i + 2;
}
public static int getParent(int i) {
return (i - 1) / 2;
}
public boolean estVide() {
return t.isEmpty();
}
#Override
public String toString() {
String str = "";
int size = t.size();
if (size > 0) {
str += "[" + t.get(0);
str += toString(0);
str += "]";
}
return str;
}
public boolean testTas() {
int size = t.size();
int check = 0;
if (size > 0) {
for (int i = 0; i < t.size(); i++) {
if (getFilsGauche(i) < size) {
if (t.get(i) < t.get(getFilsGauche(i))) {
check++;
}
}
if (getFilsDroit(i) < size) {
if (t.get(i) < t.get(getFilsDroit(i))) {
check++;
}
}
}
}
return check == 0;
}
public String toString(int i) {
String str = "";
int size = t.size();
if (getFilsGauche(i) < size) {
str += "[";
str += t.get(getFilsGauche(i));
str += toString(getFilsGauche(i));
str += "]";
}
if (getFilsDroit(i) < size) {
str += "[";
str += t.get(getFilsDroit(i));
str += toString(getFilsDroit(i));
str += "]";
}
return str;
}
//insert value and sort
public void inser(int value) {
t.add(value);
int index = t.size() - 1;
if (index > 0) {
inserCheck(index); // O(log n)
}
}
public void inserCheck(int i) {
int temp = 0;
int parent = getParent(i);
if (parent >= 0 && t.get(i) > t.get(parent)) {
temp = t.get(parent);
t.set(parent, t.get(i));
t.set(i, temp);
inserCheck(parent);
}
}
//switch position of last element is list with first (deletes first and return it)
public int supprMax() {
int size = t.size();
int max = 0;
if (size > 0) {
max = t.get(0);
t.set(0, t.get(size - 1));
t.remove(size - 1);
supprMax(0);
}
else {
throw new IllegalStateException();
}
return max;
}
public void supprMax(int i) {
int size = t.size();
int temp = 0;
int index = i;
if (getFilsGauche(i) < size && t.get(getFilsGauche(i)) > t.get(index)) {
index = getFilsGauche(i);
}
if (getFilsDroit(i) < size && t.get(getFilsDroit(i)) > t.get(index)) {
index = getFilsDroit(i);
}
if (index != i) {
temp = t.get(index);
t.set(index, t.get(i));
t.set(i, temp);
supprMax(index);
}
}
public static void tri(int[] tab) {
Tas tas = new Tas();
for (int i = 0; i < tab.length; i++) {
tas.inser(tab[i]);
}
for (int i = 0; i < tab.length; i++) {
tab[i] = tas.supprMax();
}
}
}
The last 3 lines of the test are :
-2145024521
-2147061786
-2145666206
But the last 3 of my code are :
-2145024521
-2145666206
-2147061786
The problem are probably with the inser and supprMax methods.
I hate to get a bad grade just because of 3 lines placement, because it is a program that verify the code, it dosn't care the the solution was close, it's still says it's wrong.

Java program using union-find + treemap too slow

I am solving this challenge: https://open.kattis.com/problems/virtualfriends
My solution seems to be working but kattis's test cases are running too slowly so I was wondering how I can improve code efficiency. I am using a custom made union-find structure to do this, storing "friends" into a treemap to reference.
import java.util.*;
public class virtualfriends {
public static void main(String[] args) {
Scanner scan = new Scanner(System.in);
int testcases = scan.nextInt();
scan.nextLine();
for (int i= 0; i < testcases; i++) {
int numFriendships = scan.nextInt();
scan.nextLine();
TreeMap<String , Integer> map = new TreeMap<String , Integer>();
int cnt = 0;
UF unionFind = new UF(50000);
for (int j = 0; j < numFriendships; j++)
{
String p1 = scan.next();
String p2 = scan.next();
if (!map.containsKey(p1)) map.put(p1, cnt++);
if (!map.containsKey(p2)) map.put(p2, cnt++);
unionFind.unify(map.get(p1), map.get(p2));
System.out.printf("%d\n", unionFind.getSetSize(map.get(p2)));
}
}
}
static class UF{
private int[] id, setSize;
private int numSets;
public UF(int size) {
id = new int[size] ;
setSize = new int[size];
numSets = size;
for(int i = 0 ; i < size ; ++i) {
id[i] = i;
setSize[i] = 1;
}
}
int find(int i )
{
int root = i;
while (root != id[root]) {
root = id[root];
}
while (i != root) {
int newp = id[i];
id[i] = root;
i = newp;
}
return root;
}
boolean isConnected(int i , int j) {
return find(i) == find(j);
}
int getNumSets() {
return numSets;
}
int getSetSize(int i) {
return setSize[find(i)];
}
boolean isSameSet(int i, int j) {
return find(i) == find(j);
}
void unify(int i, int j)
{
int root1 = find(i);
int root2 = find(j);
if (root1 == root2) return;
if (setSize[root1] < setSize[root2])
{
setSize[root2] += setSize[root1];
id[root1] = root2;
} else {
setSize[root1] += setSize[root2];
id[root2] = root1;
}
numSets--;
}
}
}

hashtable In C giving wrong output

I have created hashtable to solve one of the problem on hackerearth using java.
link:https://www.hackerearth.com/practice/data-structures/hash-tables/basics-of-hash-tables/practice-problems/algorithm/mind-palaces-3/
My java solution is able to pass all the test cases.
Now with same logic ,I am creating solution in C.
But by solution in C does not pass all test cases.
I am learning C.
Please help me to find the problem in C code.
Thanks.
My Java Code:
package hashtable;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;
public class MindPalces {
public static void main(String[] args) throws NumberFormatException,
IOException {
BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
String string = br.readLine();
String strArr[] = string.split(" ");
int n = Integer.parseInt(strArr[0]);
int m = Integer.parseInt(strArr[1]);
SinglyLinkLists hashTable[] = new SinglyLinkLists[337];
long array[][] = new long[n][m];
for (int i = 0; i < n; i++) {
String string1 = br.readLine();
String[] innerArr = string1.split(" ");
for (int j = 0; j < m; j++) {
array[i][j] = Long.parseLong((innerArr[j]));
}
}
for (int i = 0; i < n; i++) {
for (int j = 0; j < m; j++) {
long data = array[i][j];
if (data < 0) {
data = data * -1;
}
int index = Hashs.hashFunction(data);
if (hashTable[index] == null) {
SinglyLinkLists linkList = new SinglyLinkLists();
linkList.insertAtEnd(i, j, array[i][j]);
hashTable[index] = linkList;
} else {
SinglyLinkLists linkList = hashTable[index];
linkList.insertAtEnd(i, j, array[i][j]);
}
}
}
int q = Integer.parseInt(br.readLine());
for (int i = 0; i < q; i++) {
long val = Long.parseLong(br.readLine());
long val1 = val;
if (val < 0) {
val = val * -1;
}
int index = Hashs.hashFunction(val);
if (hashTable[index] == null) {
System.out.println("-1" + " " + "-1");
} else {
SinglyLinkLists linkList = hashTable[index];
NOde list = linkList.getNode(val1);
if (null != list) {
System.out.println(list.getI1() + " " + list.getJ1());
} else {
System.out.println("-1" + " " + "-1");
}
}
}
}
}
class Hashs {
public static int hashFunction(long key) {
return (int) (key % 337);
}
}
class NOde {
int i1 = 0;
int j1 = 0;
NOde link = null;
long number = 0;
public NOde() {
link = null;
i1 = 0;
j1 = 0;
number = 0;
}
public NOde(NOde node, int i1, int j1, long number) {
this.i1 = i1;
this.j1 = j1;
link = node;
this.number = number;
}
public NOde getLink() {
return link;
}
public void setLink(NOde link) {
this.link = link;
}
public int getI1() {
return i1;
}
public void setI1(int i1) {
this.i1 = i1;
}
public int getJ1() {
return j1;
}
public void setJ1(int j1) {
this.j1 = j1;
}
public long getNumber() {
return number;
}
public void setNumber(long number) {
this.number = number;
}
}
class SinglyLinkLists {
NOde start = null;
NOde end = null;
int size = 0;
public SinglyLinkLists() {
start = null;
end = null;
size = 0;
}
public void insertAtEnd(int i1, int j1, long number) {
NOde nptr = new NOde(null, i1, j1, number);
size++;
if (start == null) {
start = nptr;
end = start;
} else {
end.setLink(nptr);
end = nptr;
}
}
// Function to display elements
public NOde getNode(long val) {
if (start.getNumber() == val) {
return start;
}
NOde ptr = start;
ptr = start.getLink();
while (null != ptr && ptr.getLink() != null) {
if (ptr.getNumber() == val) {
return ptr;
}
ptr = ptr.getLink();
}
if (null != ptr && ptr.getNumber() == val) {
return ptr;
}
return null;
}
}
My C Code:
#include<stdio.h>
#include <stdlib.h>
struct Node
{
int i1;
int j1;
long inputVal;
struct Node * next;
};
struct hash
{
struct Node * head;
};
struct hash *hashTable=NULL;
struct Node *createNode(int i,int j,long number)
{
struct Node *list;
list = (struct Node *) malloc(sizeof(struct Node));
list->i1 = i;
list->j1 = j;
list->inputVal=number;
list->next = NULL;
return list;
};
int main()
{
int m=0,n=0;
scanf("%d%d",&m,&n);
hashTable=(struct hash *)calloc(337,sizeof(struct Node));
for(int i=0; i<n; i++)
{
for(int j=0; j<m; j++)
{
long data=0;
long copy=0;
scanf("%ld",&data);
copy=data;
if(data<0)
{
data=data*-1;
}
int index=hashFunction(data);
struct Node *newnode = createNode(i,j,copy);
struct Node *myNode = hashTable[index].head;
if (!myNode)
{
hashTable[index].head=newnode;
}
else
{
newnode->next=hashTable[index].head;
hashTable[index].head=newnode;
}
}
}
int q=0;
scanf("%d",&q);
for(int i=0; i<q; i++)
{
long val=0;
long copy=0;
int boolean=0;
scanf("%ld",&val);
copy=val;
if(val<0)
{
val=val*-1;
}
int index=hashFunction(val);
struct Node *myNode = hashTable[index].head;
if(!myNode)
{
printf("%d %d\n",-1,-1);
}
else
{
while (myNode->next!= NULL)
{
if (myNode->inputVal==copy)
{
boolean=1;
printf("%d %d\n",myNode->i1,myNode->j1);
break;
}
myNode = myNode->next;
}
if(myNode->inputVal==copy &&!boolean)
{
boolean=1;
printf("%d %d\n",myNode->i1,myNode->j1);
}
if(!boolean)
{
printf("%d %d\n",-1,-1);
}
}
}
return 0;
}
int hashFunction(long data)
{
return (int)(data%337);
}
Sample Input:
5 5
-993655555 -758584352 -725954642 -696391700 -649643547
-591473088 -568010221 -432112275 -421496588 -351507172
-323741602 -232192004 -30134637 -369573 100246476
156824549 174266331 392354039 601294716 763826005
768378344 802829330 818988557 992012759 999272829
10
156824549
-758584352
-993655555
601294716
-696391700
802829330
-993655555
-232192004
392354039
-568010221

Knuth-Morris-Pratt algorithm confusion

I'm trying to implement KMP algorithm. My algorithm works correctly with the following example
Text: 121121
Pattern: 121
Result: 1,4
But when Text is 12121 and pattern is the same as above, result just: 1. I don't know if this is the problem of the algorithm or of my implementation?
Other example:
Text: 1111111111
Pattern: 111
Result: 1,4,7
My code is:
public class KMP {
public static void main(String[] args) throws IOException{
BufferedReader reader = new BufferedReader(new InputStreamReader(System.in));
String text = reader.readLine();
String pattern = reader.readLine();
search(text,pattern);
}
private static void search(String text,String pattern)
{
int[] Pi = Pi(pattern);
for (int i = 0,q=0; i <text.length()&&q<pattern.length() ; i++,q++) {
while (q>=0 && pattern.charAt(q)!=text.charAt(i))
{
q=Pi[q];
}
if(q==pattern.length()-1) {
System.out.println(i-pattern.length()+2);
q=Pi[q];
}
}
}
private static int[] Pi(String p) {
int[] Pi = new int[p.length()];
Pi[0]=-1;
int i=0;
int j=-1;
while (i<p.length()-1) {
while (j>=0 && p.charAt(j)!=p.charAt(i))
{
j=Pi[j];
}
i++;
j++;
if(p.charAt(j)==p.charAt(i)) Pi[i]=Pi[j];
else Pi[i]=j;
}
return Pi;
}
}
Hope help you.
public int strStr(String source, String target) {
if (source == null || target == null){
return -1;
}
if (source.isEmpty() && !target.isEmpty()){
return -1;
}
if (source.isEmpty() && target.isEmpty()){
return 0;
}
if (target.isEmpty()){
return 0;
}
int index = 0;
int compare_index = 0;
int compare_start_index = 0;
int compare_same_length = 0;
List<Integer> answers = new ArrayList<Integer>();
while (true){
if (compare_same_length ==0){
compare_start_index = compare_index;
}
if (source.charAt(compare_index) == target.charAt(index)){
compare_same_length++;
index++;
} else {
if (compare_same_length >0){
compare_index--;
}
compare_same_length = 0;
index = 0;
}
compare_index++;
if (compare_same_length == target.length()){
answers.add(compare_start_index+1);
compare_same_length=0;
index=0;
}
if (compare_index == source.length()){
//here are answers
for (int i = 0; i < answers.size(); i++) {
int value = answers.get(i);
}
return 1;
}
}
}

Viterbi algorithm in java

I'm doing the coursera NLP course and the first programming assignment is to build a Viterbi decoder. I think I'm really close to finishing it but there is some elusive bug which I cannot seem to be able to trace. Here is my code:
http://pastie.org/private/ksmbns3gjctedu1zxrehw
http://pastie.org/private/ssv6tc8dwnamn2qegdvww
So far I've debugged the "teaching" related functions so I can say that the parameters for the algorithms are being correctly estimated. Of particular interest is the viterbi() and findW() methods. The definition of the algorithm I'm using can be found here: http://www.cs.columbia.edu/~mcollins/hmms-spring2013.pdf on page 18.
One thing which I'm having hard time wrapping my head around is how am I supposed to update the backpointers for the special cases when K = {1, 2} (in my case this is 0 and 1, since I'm zero-indexing my array) respectively the parameters I'm using in those cases are q({TAGSET} | *, *) and q ({TAGSET} | *, {TAGSET}).
Hints rather than spoon-fed answers will also be highly appreciated!
Here's a simple implementation of viterbi decoder by Yusuke Shunyama =) http://cs.nyu.edu/yusuke/course/NLP/viterbi/Viterbi.java
/*
* Viterbi.java
* Toy Viterbi Decorder
*
* by Yusuke Shinyama <yusuke at cs . nyu . edu>
*
* Permission to use, copy, modify, distribute this software
* for any purpose is hereby granted without fee, provided
* that the above copyright notice appear in all copies and
* that both that copyright notice and this permission notice
* appear in supporting documentation.
*/
import java.awt.*;
import java.util.*;
import java.text.*;
import java.awt.event.*;
import java.applet.*;
class Symbol {
public String name;
public Symbol(String s) {
name = s;
}
}
class SymbolTable {
Hashtable table;
public SymbolTable() {
table = new Hashtable();
}
public Symbol intern(String s) {
s = s.toLowerCase();
Object sym = table.get(s);
if (sym == null) {
sym = new Symbol(s);
table.put(s, sym);
}
return (Symbol)sym;
}
}
class SymbolList {
Vector list;
public SymbolList() {
list = new Vector();
}
public int size() {
return list.size();
}
public void set(int index, Symbol sym) {
list.setElementAt(sym, index);
}
public void add(Symbol sym) {
list.addElement(sym);
}
public Symbol get(int index) {
return (Symbol) list.elementAt(index);
}
}
class IntegerList {
Vector list;
public IntegerList() {
list = new Vector();
}
public int size() {
return list.size();
}
public void set(int index, int i) {
list.setElementAt(new Integer(i), index);
}
public void add(int i) {
list.addElement(new Integer(i));
}
public int get(int index) {
return ((Integer)list.elementAt(index)).intValue();
}
}
class ProbTable {
Hashtable table;
public ProbTable() {
table = new Hashtable();
}
public void put(Object obj, double prob) {
table.put(obj, new Double(prob));
}
public double get(Object obj) {
Double prob = (Double)table.get(obj);
if (prob == null) {
return 0.0;
}
return prob.doubleValue();
}
// normalize probability
public void normalize() {
double total = 0.0;
for(Enumeration e = table.elements() ; e.hasMoreElements() ;) {
total += ((Double)e.nextElement()).doubleValue();
}
if (total == 0.0) {
return; // div by zero!
}
for(Enumeration e = table.keys() ; e.hasMoreElements() ;) {
Object k = e.nextElement();
double prob = ((Double)table.get(k)).doubleValue();
table.put(k, new Double(prob / total));
}
}
}
class State {
public String name;
ProbTable emits;
ProbTable linksto;
public State(String s) {
name = s;
emits = new ProbTable();
linksto = new ProbTable();
}
public void normalize() {
emits.normalize();
linksto.normalize();
}
public void addSymbol(Symbol sym, double prob) {
emits.put(sym, prob);
}
public double emitprob(Symbol sym) {
return emits.get(sym);
}
public void addLink(State st, double prob) {
linksto.put(st, prob);
}
public double transprob(State st) {
return linksto.get(st);
}
}
class StateTable {
Hashtable table;
public StateTable() {
table = new Hashtable();
}
public State get(String s) {
s = s.toUpperCase();
State st = (State)table.get(s);
if (st == null) {
st = new State(s);
table.put(s, st);
}
return st;
}
}
class StateIDTable {
Hashtable table;
public StateIDTable() {
table = new Hashtable();
}
public void put(State obj, int i) {
table.put(obj, new Integer(i));
}
public int get(State obj) {
Integer i = (Integer)table.get(obj);
if (i == null) {
return 0;
}
return i.intValue();
}
}
class StateList {
Vector list;
public StateList() {
list = new Vector();
}
public int size() {
return list.size();
}
public void set(int index, State st) {
list.setElementAt(st, index);
}
public void add(State st) {
list.addElement(st);
}
public State get(int index) {
return (State) list.elementAt(index);
}
}
class HMMCanvas extends Canvas {
static final int grid_x = 60;
static final int grid_y = 40;
static final int offset_x = 70;
static final int offset_y = 30;
static final int offset_y2 = 10;
static final int offset_y3 = 65;
static final int col_x = 40;
static final int col_y = 10;
static final int state_r = 10;
static final Color state_fill = Color.white;
static final Color state_fill_maximum = Color.yellow;
static final Color state_fill_best = Color.red;
static final Color state_boundery = Color.black;
static final Color link_normal = Color.green;
static final Color link_processed = Color.blue;
static final Color link_maximum = Color.red;
HMMDecoder hmm;
public HMMCanvas() {
setBackground(Color.white);
setSize(400,300);
}
public void setHMM(HMMDecoder h) {
hmm = h;
}
private void drawState(Graphics g, int x, int y, Color c) {
x = x * grid_x + offset_x;
y = y * grid_y + offset_y;
g.setColor(c);
g.fillOval(x-state_r, y-state_r, state_r*2, state_r*2);
g.setColor(state_boundery);
g.drawOval(x-state_r, y-state_r, state_r*2, state_r*2);
}
private void drawLink(Graphics g, int x, int y0, int y1, Color c) {
int x0 = grid_x * x + offset_x;
int x1 = grid_x * (x+1) + offset_x;
y0 = y0 * grid_y + offset_y;
y1 = y1 * grid_y + offset_y;
g.setColor(c);
g.drawLine(x0, y0, x1, y1);
}
private void drawCenterString(Graphics g, String s, int x, int y) {
x = x - g.getFontMetrics().stringWidth(s)/2;
g.setColor(Color.black);
g.drawString(s, x, y+5);
}
private void drawRightString(Graphics g, String s, int x, int y) {
x = x - g.getFontMetrics().stringWidth(s);
g.setColor(Color.black);
g.drawString(s, x, y+5);
}
public void paint(Graphics g) {
if (hmm == null) {
return;
}
DecimalFormat form = new DecimalFormat("0.0000");
int nsymbols = hmm.symbols.size();
int nstates = hmm.states.size();
// complete graph.
for(int i = 0; i < nsymbols; i++) {
int offset_ymax = offset_y2+nstates*grid_y;
if (i < nsymbols-1) {
for(int y1 = 0; y1 < nstates; y1++) {
for(int y0 = 0; y0 < nstates; y0++) {
Color c = link_normal;
if (hmm.stage == i+1 && hmm.i0 == y0 && hmm.i1 == y1) {
c = link_processed;
}
if (hmm.matrix_prevstate[i+1][y1] == y0) {
c = link_maximum;
}
drawLink(g, i, y0, y1, c);
if (c == link_maximum && 0 < i) {
double transprob = hmm.states.get(y0).transprob(hmm.states.get(y1));
drawCenterString(g, form.format(transprob),
offset_x + i*grid_x + grid_x/2, offset_ymax);
offset_ymax = offset_ymax + 16;
}
}
}
}
// state circles.
for(int y = 0; y < nstates; y++) {
Color c = state_fill;
if (hmm.matrix_prevstate[i][y] != -1) {
c = state_fill_maximum;
}
if (hmm.sequence.size() == nsymbols &&
hmm.sequence.get(nsymbols-1-i) == y) {
c = state_fill_best;
}
drawState(g, i, y, c);
}
}
// max probability.
for(int i = 0; i < nsymbols; i++) {
for(int y1 = 0; y1 < nstates; y1++) {
if (hmm.matrix_prevstate[i][y1] != -1) {
drawCenterString(g, form.format(hmm.matrix_maxprob[i][y1]),
offset_x+i*grid_x, offset_y+y1*grid_y);
}
}
}
// captions (symbols atop)
for(int i = 0; i < nsymbols; i++) {
drawCenterString(g, hmm.symbols.get(i).name, offset_x+i*grid_x, col_y);
}
// captions (states in left)
for(int y = 0; y < nstates; y++) {
drawRightString(g, hmm.states.get(y).name, col_x, offset_y+y*grid_y);
}
// status bar
g.setColor(Color.black);
g.drawString(hmm.status, col_x, offset_y3+nstates*grid_y);
g.drawString(hmm.status2, col_x, offset_y3+nstates*grid_y+16);
}
}
class HMMDecoder {
StateList states;
int state_start;
int state_end;
public IntegerList sequence;
public double[][] matrix_maxprob;
public int[][] matrix_prevstate;
public SymbolList symbols;
public double probmax;
public int stage, i0, i1;
public boolean laststage;
public String status, status2;
public HMMDecoder() {
status = "Not initialized.";
status2 = "";
states = new StateList();
}
public void addStartState(State st) {
state_start = states.size(); // get current index
states.add(st);
}
public void addNormalState(State st) {
states.add(st);
}
public void addEndState(State st) {
state_end = states.size(); // get current index
states.add(st);
}
// for debugging.
public void showmatrix() {
for(int i = 0; i < symbols.size(); i++) {
for(int j = 0; j < states.size(); j++) {
System.out.print(matrix_maxprob[i][j]+" "+matrix_prevstate[i][j]+", ");
}
System.out.println();
}
}
// initialize for decoding
public void initialize(SymbolList syms) {
// symbols[syms.length] should be END
symbols = syms;
matrix_maxprob = new double[symbols.size()][states.size()];
matrix_prevstate = new int[symbols.size()][states.size()];
for(int i = 0; i < symbols.size(); i++) {
for(int j = 0; j < states.size(); j++) {
matrix_prevstate[i][j] = -1;
}
}
State start = states.get(state_start);
for(int i = 0; i < states.size(); i++) {
matrix_maxprob[0][i] = start.transprob(states.get(i));
matrix_prevstate[0][i] = 0;
}
stage = 0;
i0 = -1;
i1 = -1;
sequence = new IntegerList();
status = "Ok, let's get started...";
status2 = "";
}
// forward procedure
public boolean proceed_decoding() {
status2 = "";
// already end?
if (symbols.size() <= stage) {
return false;
}
// not started?
if (stage == 0) {
stage = 1;
i0 = 0;
i1 = 0;
matrix_maxprob[stage][i1] = 0.0;
} else {
i0++;
if (states.size() <= i0) {
// i0 should be reinitialized.
i0 = 0;
i1++;
if (states.size() <= i1) {
// i1 should be reinitialized.
// goto next stage.
stage++;
if (symbols.size() <= stage) {
// done.
status = "Decoding finished.";
return false;
}
laststage = (stage == symbols.size()-1);
i1 = 0;
}
matrix_maxprob[stage][i1] = 0.0;
}
}
// sym1: next symbol
Symbol sym1 = symbols.get(stage);
State s0 = states.get(i0);
State s1 = states.get(i1);
// precond: 1 <= stage.
double prob = matrix_maxprob[stage-1][i0];
DecimalFormat form = new DecimalFormat("0.0000");
status = "Prob:" + form.format(prob);
if (1 < stage) {
// skip first stage.
double transprob = s0.transprob(s1);
prob = prob * transprob;
status = status + " x " + form.format(transprob);
}
double emitprob = s1.emitprob(sym1);
prob = prob * emitprob;
status = status + " x " + form.format(emitprob) + "(" + s1.name+":"+sym1.name + ")";
status = status + " = " + form.format(prob);
// System.out.println("stage: "+stage+", i0:"+i0+", i1:"+i1+", prob:"+prob);
if (matrix_maxprob[stage][i1] < prob) {
matrix_maxprob[stage][i1] = prob;
matrix_prevstate[stage][i1] = i0;
status2 = "(new maximum found)";
}
return true;
}
// backward proc
public void backward() {
int probmaxstate = state_end;
sequence.add(probmaxstate);
for(int i = symbols.size()-1; 0 < i; i--) {
probmaxstate = matrix_prevstate[i][probmaxstate];
if (probmaxstate == -1) {
status2 = "Decoding failed.";
return;
}
sequence.add(probmaxstate);
//System.out.println("stage: "+i+", state:"+probmaxstate);
}
}
}
public class Viterbi extends Applet implements ActionListener, Runnable {
SymbolTable symtab;
StateTable sttab;
HMMDecoder myhmm = null;
HMMCanvas canvas;
Panel p;
TextArea hmmdesc;
TextField sentence;
Button bstart, bskip;
static final String initialHMM =
"start: go(cow,1.0)\n" +
"cow: emit(moo,0.9) emit(hello,0.1) go(cow,0.5) go(duck,0.3) go(end,0.2)\n" +
"duck: emit(quack,0.6) emit(hello,0.4) go(duck,0.5) go(cow,0.3) go(end,0.2)\n";
final int sleepmillisec = 100; // 0.1s
// setup hmm
// success:true.
boolean setupHMM(String s) {
myhmm = new HMMDecoder();
symtab = new SymbolTable();
sttab = new StateTable();
State start = sttab.get("start");
State end = sttab.get("end");
myhmm.addStartState(start);
boolean success = true;
StringTokenizer lines = new StringTokenizer(s, "\n");
while (lines.hasMoreTokens()) {
// foreach line.
String line = lines.nextToken();
int i = line.indexOf(':');
if (i == -1) break;
State st0 = sttab.get(line.substring(0,i).trim());
if (st0 != start && st0 != end) {
myhmm.addNormalState(st0);
}
//System.out.println(st0.name+":"+line.substring(i+1));
StringTokenizer tokenz = new StringTokenizer(line.substring(i+1), ", ");
while (tokenz.hasMoreTokens()) {
// foreach token.
String t = tokenz.nextToken().toLowerCase();
if (t.startsWith("go(")) {
State st1 = sttab.get(t.substring(3).trim());
// fetch another token.
if (!tokenz.hasMoreTokens()) {
success = false; // err. nomoretoken
break;
}
String n = tokenz.nextToken().replace(')', ' ');
double prob;
try {
prob = Double.valueOf(n).doubleValue();
} catch (NumberFormatException e) {
success = false; // err.
prob = 0.0;
}
st0.addLink(st1, prob);
//System.out.println("go:"+st1.name+","+prob);
} else if (t.startsWith("emit(")) {
Symbol sym = symtab.intern(t.substring(5).trim());
// fetch another token.
if (!tokenz.hasMoreTokens()) {
success = false; // err. nomoretoken
break;
}
String n = tokenz.nextToken().replace(')', ' ');
double prob;
try {
prob = Double.valueOf(n).doubleValue();
} catch (NumberFormatException e) {
success = false; // err.
prob = 0.0;
}
st0.addSymbol(sym, prob);
//System.out.println("emit:"+sym.name+","+prob);
} else {
// illegal syntax, just ignore
break;
}
}
st0.normalize(); // normalize probability
}
end.addSymbol(symtab.intern("end"), 1.0);
myhmm.addEndState(end);
return success;
}
// success:true.
boolean setup() {
if (! setupHMM(hmmdesc.getText()))
return false;
// initialize words
SymbolList words = new SymbolList();
StringTokenizer tokenz = new StringTokenizer(sentence.getText());
words.add(symtab.intern("start"));
while (tokenz.hasMoreTokens()) {
words.add(symtab.intern(tokenz.nextToken()));
}
words.add(symtab.intern("end"));
myhmm.initialize(words);
canvas.setHMM(myhmm);
return true;
}
public void init() {
canvas = new HMMCanvas();
setLayout(new BorderLayout());
p = new Panel();
sentence = new TextField("moo hello quack", 20);
bstart = new Button(" Start ");
bskip = new Button("Auto");
bstart.addActionListener(this);
bskip.addActionListener(this);
p.add(sentence);
p.add(bstart);
p.add(bskip);
hmmdesc = new TextArea(initialHMM, 4, 20);
add("North", canvas);
add("Center", p);
add("South", hmmdesc);
}
void setup_fallback() {
// adjustable
State cow = sttab.get("cow");
State duck = sttab.get("duck");
State end = sttab.get("end");
cow.addLink (cow, 0.5);
cow.addLink (duck, 0.3);
cow.addLink (end, 0.2);
duck.addLink (cow, 0.3);
duck.addLink (duck, 0.5);
duck.addLink (end, 0.2);
cow.addSymbol(symtab.intern("moo"), 0.9);
cow.addSymbol(symtab.intern("hello"), 0.1);
duck.addSymbol(symtab.intern("quack"), 0.6);
duck.addSymbol(symtab.intern("hello"), 0.4);
}
public void destroy() {
remove(p);
remove(canvas);
}
public void processEvent(AWTEvent e) {
if (e.getID() == Event.WINDOW_DESTROY) {
System.exit(0);
}
}
public void run() {
if (myhmm != null) {
while (myhmm.proceed_decoding()) {
canvas.repaint();
try {
Thread.sleep(sleepmillisec);
} catch (InterruptedException e) {
;
}
}
myhmm.backward();
canvas.repaint();
bstart.setLabel(" Start ");
bstart.setEnabled(true);
bskip.setEnabled(true);
myhmm = null;
}
}
public void actionPerformed(ActionEvent ev) {
String label = ev.getActionCommand();
if (label.equalsIgnoreCase(" start ")) {
if (!setup()) {
// error
return;
}
bstart.setLabel("Proceed");
canvas.repaint();
} else if (label.equalsIgnoreCase("proceed")) {
// next
if (! myhmm.proceed_decoding()) {
myhmm.backward();
bstart.setLabel(" Start ");
myhmm = null;
}
canvas.repaint();
} else if (label.equalsIgnoreCase("auto")) {
// skip
if (myhmm == null) {
if (!setup()) {
// error
return;
}
}
bstart.setEnabled(false);
bskip.setEnabled(false);
Thread me = new Thread(this);
me.setPriority(Thread.MIN_PRIORITY);
// start animation.
me.start();
}
}
public static void main(String args[]) {
Frame f = new Frame("Viterbi");
Viterbi v = new Viterbi();
f.add("Center", v);
f.setSize(400, 400);
f.show();
v.init();
v.start();
}
public String getAppletInfo() {
return "A Sample Viterbi Decoder Applet";
}
}
Here are some suggestions:
You should draw out the HMM lattice if you have any confusion about how transitions are occurring in your model.
For instance, you can view transitions to the first hidden state as originating from a single start hidden state, so the backpointer for k=0 would always have to point to this start state. In your code, you probably should not be looping over hidden states for k=0 in findW (which seems correct), but you probably should be looping for k=1.
Usually multiplying transition and emission probabilities in HMM inference lead to very small floating point values, which can lead to numerical errors. You should add log probabilities instead of multiplying probabilities.
To check viterbi or forward-backward implementations, I usually also write a brute force method and compare the output of each. If the brute-force and dynamic-programming algorithm match on short sequences, then that gives a reasonable measure of confidence that both are correct.

Categories

Resources