Stock price predictions using Baum-Welch algorithm
$begingroup$
I am testing my implementation of Baum-Welch algorithm for stock market data. The sequence for the price of a certain stock ends with this:
73 76 76 76 76 75 78 78 76 74 78 78 79 80 83 82 82 81 80 79 78 77 76 77 76 75 77 78 79 80 78 75 75 76 75 76 77 76 75 76 81
Now I try to make my model estimate the last observation:
int obs = { 73, 76, 76, 76, 76, 75, 78, 78, 76, 74, 78, 78, 79, 80, 83, 82, 82,81, 80, 79, 78, 77, 76,77, 76, 75, 77, 78, 79, 80, 78, 75, 75, 76, 75, 76, 77, 76, 75, 76};
BaumWelch model = new BaumWelch(100, 100);
model.run(obs);
double optimumP = 0;
int estimate = -1;
double transitions = Util.normalize(model.getAlpha(obs)); // FIXME: Explain
double nextObservations = Util.normalize(model.getNextObservationProbabilities(transitions));
double maxP = 0;
int tmp = 0;
for (int j = 0; j < nextObservations.length; j++) {
if (nextObservations[j] > maxP) {
maxP = nextObservations[j];
tmp = j;
}
}
if (maxP > optimumP) {
estimate = tmp;
}
System.out.println(estimate);
The output is 75, not very good. I try to estimate the second to last observation with code exactly like above just removing the last (76) and the output becomes 77 which is good enough. I test once again removing the "75" and see if it can estimate it: Output is 75. I remove the last observation again (76) and see what the program estimates. Output is 76.
So it seems close to making predictions in the right direction most of the time. Naturally I will test this much more, but I am looking for bugs regarding correctness of my implementation of Baum-Welch algorithm.
The actual model code is the following.
import java.lang.Math;
class BaumWelch {
int iterationCount = 30; // Arbitrary count of maximum iterations
int stateCount; // Count of hidden states e.g. number of birds
int observationCount; // Count of observations e.g. nu
double a; // Transition matrix
double b; // Observation matrix
double pi; // Initialization vector
void run(int observations) { // FIXME: Explain the Baum-Welch algorithm
double numer, denom, logP, maxP;
double xi = new double[observations.length][stateCount][stateCount];
double alpha = new double[observations.length][stateCount];
double beta = new double[observations.length][stateCount];
double gamma = new double[observations.length][stateCount];
double scaling = new double[observations.length];
maxP = -Double.MAX_VALUE; // FIXME: Explain initialization of logaritmic probability
for(int count=0; count < iterationCount; count++){
logP = 0;
forward(observations, alpha, scaling);
backward(observations, beta, scaling);
for (int time = 0; time < scaling.length; time++) { //FIXME: Explain
logP = scaling[time] != 0 ? logP + Math.log(scaling[time]) : logP;
}
logP = -logP;
if (logP < maxP) break; //improve or break loop
maxP = logP;
for (int time = 0; time < observations.length - 1; time++) { // Σ observations Σ originstates Σ transitionstates
denom = 0; // Calculate a denominator for each transition between possible states
for (int originState = 0; originState < stateCount; originState++) { // ΣΣ
for (int transitionState = 0; transitionState < stateCount; transitionState++) { // ΣΣ
denom += alpha[time][originState] * a[originState][transitionState] * b[transitionState][observations[time + 1]] * beta[(time + 1)][transitionState]; // FIXME: Explain this is count of possible states in eq. 37 http://www.cs.ubc.ca/~murphyk/Bayes/rabiner.pdf
}
}
for (int originState = 0; originState < stateCount; originState++) {
gamma[time][originState] = 0;
for (int transitionState = 0; transitionState < stateCount; ++transitionState) { // divide the count of good states by the count of possible states
xi[time][originState][transitionState] = denom != 0 ? (alpha[time][originState] * a[originState][transitionState] * b[transitionState][observations[time + 1]] * beta[(time + 1)][transitionState]) / denom : 0;
gamma[time][originState] += xi[time][originState][transitionState]; // formula 38 in Rabiner (89)
}
}
}
denom = 0;
for (int state = 0; state < stateCount; state++) { // FIXME: Explain line below. count of possible states
denom += alpha[observations.length - 1][state] * beta[observations.length - 1][state];
}
for (int state = 0; state < stateCount; state++) { // Calculate gamma for the last step. formula 27 in Rabiner(89)
gamma[observations.length - 1][state] = 0.0;
gamma[observations.length - 1][state] += (alpha[observations.length - 1][state] * beta[observations.length - 1][state]) / denom;
}
/* Update estimates */
for (int originState = 0; originState < stateCount; originState++) {
pi[originState] = gamma[0][originState]; //initial vector
for (int transitionState = 0; transitionState < stateCount; transitionState++) {
numer = 0.0;
denom = 0.0;
for (int t = 0; t < observations.length - 1; ++t) {
numer += xi[t][originState][transitionState];
denom += gamma[t][originState];
}
a[originState][transitionState] = denom != 0 ? numer / denom : a[originState][transitionState]; //transition matrix
}
for (int transitionState = 0; transitionState < observationCount; transitionState++) {
numer = 0.0;
denom = 0.0;
for (int time = 0; time < observations.length; time++) {
if (transitionState == observations[time]) {
numer += gamma[time][originState];
}
denom += gamma[time][originState];
}
b[originState][transitionState] = denom != 0 ? numer / denom : b[originState][transitionState]; //observation matrix
}
}
}
}
/* Calculate the distribution (array) of next step observations */
double getNextObservationProbabilities(double pCurrent) {
double returnProb = new double[observationCount];
for (int currentState = 0; currentState < stateCount; currentState++) {
for (int nextState = 0; nextState < stateCount; nextState++) {
for (int observation = 0; observation < observationCount; observation++) {
returnProb[observation] += pCurrent[nextState] * a[nextState][currentState] * b[currentState][observation];
}
}
}
return returnProb;
}
double getSequenceProbability(int observations) { // Calculate the probability of a given sequence "observations"
double returnProb = 0;
double alpha = new double[observations.length][stateCount]; // FIXME: Explain
for (int i = 0; i < stateCount; i++) {
alpha[0][i] = pi[i] * b[i][observations[0]]; // FIXME: Explain
}
for (int time = 0; ++time < observations.length;) { // FIXME: Explain indexing
for (int i = 0; i < stateCount; i++) {
for (int j = 0; j < stateCount; j++) {
alpha[time][i] += alpha[(time - 1)][j] * a[j][i] * b[i][observations[time]]; // FIXME: Explain
}
}
}
for (int i = 0; i < stateCount; i++) {
returnProb += alpha[(alpha.length - 1)][i];
}
return returnProb;
}
void forward(int observations, double alpha, double scaling) { // FIXME: Explain https://en.wikipedia.org/wiki/Forward_algorithm
scaling[0] = 0;
for (int state = 0; ++state < stateCount;) {
alpha[0][state] = pi[state] * b[state][observations[0]];
scaling[0] += alpha[0][state];
}
scaling[0] = scaling[0] != 0 ? 1.0 / scaling[0] : scaling[0];
for (int state = 0; state < stateCount; state++) {
alpha[0][state] = scaling[0] * alpha[0][state];
}
for (int time = 1; time < observations.length; time++) { // Forward in time
scaling[time] = 0;
for (int state = 0; state < stateCount; state++) {
double sum = 0;
alpha[time][state] = 0;
for (int originState = 0; originState < stateCount; originState++) {
sum += alpha[time - 1][originState] * a[originState][state];
}
alpha[time][state] = sum * b[state][observations[time]];
scaling[time] += sum;
}
scaling[time] = scaling[time] != 0 ? scaling[time] = 1.0 / scaling[time] : scaling[time];
for (int i = 0; i < stateCount; i++) {
alpha[time][i] = scaling[time] * alpha[time][i];
}
}
}
void backward(int observations, double beta, double scaling) { // FIXME: Explain https://en.wikipedia.org/wiki/Forward%E2%80%93backward_algorithm
for (int state = 0; state < stateCount; state++) {
beta[(observations.length - 1)][state] = scaling[observations.length - 1]; // Base-case beta for last observation
}
for (int time = observations.length - 2; time >= 0; time--) { // Backward in time
for (int state = 0; state < stateCount; state++) {
beta[time][state] = 0;
for (int j = 0; j < stateCount; j++) {
beta[time][state] += a[state][j] * b[j][observations[time + 1]] * beta[time + 1][j];
}
beta[time][state] = scaling[time] * beta[time][state];
}
}
}
double getAlpha(int observations) {
double alpha = new double[observations.length][stateCount];
for (int state = 0; state < stateCount; state++) {
alpha[0][state] = pi[state] * b[state][observations[0]]; // Base-case alpha for time = 0
}
for (int time = 1; time < observations.length; time++) {
for (int state = 0; state < stateCount; state++) {
for (int j = 0; j < stateCount; j++) {
alpha[time][state] += alpha[(time - 1)][j] * a[j][state] * b[state][observations[time]]; // FIXME: Explain
}
}
}
return alpha[observations.length - 1];
}
BaumWelch(int stateCount, int observationCount) { // Constructor
this.stateCount = stateCount;
this.observationCount = observationCount;
this.a = new double[stateCount][stateCount];
this.b = new double[stateCount][observationCount];
this.pi = new double[stateCount];
for (int i = 0; i < stateCount; i++) {
for (int j = 0; j < stateCount; j++) {
this.a[i][j] = 1000 + Math.random() * 10000;
}
}
this.a = Util.normalize(a);
for (int i = 0; i < stateCount; i++) {
for (int j = 0; j < observationCount; j++) {
this.b[i][j] = 1000 + Math.random() * 500;
}
}
this.b = Util.normalize(b);
for (int i = 0; i < stateCount; i++) {
this.pi[i] = 1000 + Math.random() * 500;
}
this.pi = Util.normalize(pi);
}
}
java statistics ai machine-learning
$endgroup$
add a comment |
$begingroup$
I am testing my implementation of Baum-Welch algorithm for stock market data. The sequence for the price of a certain stock ends with this:
73 76 76 76 76 75 78 78 76 74 78 78 79 80 83 82 82 81 80 79 78 77 76 77 76 75 77 78 79 80 78 75 75 76 75 76 77 76 75 76 81
Now I try to make my model estimate the last observation:
int obs = { 73, 76, 76, 76, 76, 75, 78, 78, 76, 74, 78, 78, 79, 80, 83, 82, 82,81, 80, 79, 78, 77, 76,77, 76, 75, 77, 78, 79, 80, 78, 75, 75, 76, 75, 76, 77, 76, 75, 76};
BaumWelch model = new BaumWelch(100, 100);
model.run(obs);
double optimumP = 0;
int estimate = -1;
double transitions = Util.normalize(model.getAlpha(obs)); // FIXME: Explain
double nextObservations = Util.normalize(model.getNextObservationProbabilities(transitions));
double maxP = 0;
int tmp = 0;
for (int j = 0; j < nextObservations.length; j++) {
if (nextObservations[j] > maxP) {
maxP = nextObservations[j];
tmp = j;
}
}
if (maxP > optimumP) {
estimate = tmp;
}
System.out.println(estimate);
The output is 75, not very good. I try to estimate the second to last observation with code exactly like above just removing the last (76) and the output becomes 77 which is good enough. I test once again removing the "75" and see if it can estimate it: Output is 75. I remove the last observation again (76) and see what the program estimates. Output is 76.
So it seems close to making predictions in the right direction most of the time. Naturally I will test this much more, but I am looking for bugs regarding correctness of my implementation of Baum-Welch algorithm.
The actual model code is the following.
import java.lang.Math;
class BaumWelch {
int iterationCount = 30; // Arbitrary count of maximum iterations
int stateCount; // Count of hidden states e.g. number of birds
int observationCount; // Count of observations e.g. nu
double a; // Transition matrix
double b; // Observation matrix
double pi; // Initialization vector
void run(int observations) { // FIXME: Explain the Baum-Welch algorithm
double numer, denom, logP, maxP;
double xi = new double[observations.length][stateCount][stateCount];
double alpha = new double[observations.length][stateCount];
double beta = new double[observations.length][stateCount];
double gamma = new double[observations.length][stateCount];
double scaling = new double[observations.length];
maxP = -Double.MAX_VALUE; // FIXME: Explain initialization of logaritmic probability
for(int count=0; count < iterationCount; count++){
logP = 0;
forward(observations, alpha, scaling);
backward(observations, beta, scaling);
for (int time = 0; time < scaling.length; time++) { //FIXME: Explain
logP = scaling[time] != 0 ? logP + Math.log(scaling[time]) : logP;
}
logP = -logP;
if (logP < maxP) break; //improve or break loop
maxP = logP;
for (int time = 0; time < observations.length - 1; time++) { // Σ observations Σ originstates Σ transitionstates
denom = 0; // Calculate a denominator for each transition between possible states
for (int originState = 0; originState < stateCount; originState++) { // ΣΣ
for (int transitionState = 0; transitionState < stateCount; transitionState++) { // ΣΣ
denom += alpha[time][originState] * a[originState][transitionState] * b[transitionState][observations[time + 1]] * beta[(time + 1)][transitionState]; // FIXME: Explain this is count of possible states in eq. 37 http://www.cs.ubc.ca/~murphyk/Bayes/rabiner.pdf
}
}
for (int originState = 0; originState < stateCount; originState++) {
gamma[time][originState] = 0;
for (int transitionState = 0; transitionState < stateCount; ++transitionState) { // divide the count of good states by the count of possible states
xi[time][originState][transitionState] = denom != 0 ? (alpha[time][originState] * a[originState][transitionState] * b[transitionState][observations[time + 1]] * beta[(time + 1)][transitionState]) / denom : 0;
gamma[time][originState] += xi[time][originState][transitionState]; // formula 38 in Rabiner (89)
}
}
}
denom = 0;
for (int state = 0; state < stateCount; state++) { // FIXME: Explain line below. count of possible states
denom += alpha[observations.length - 1][state] * beta[observations.length - 1][state];
}
for (int state = 0; state < stateCount; state++) { // Calculate gamma for the last step. formula 27 in Rabiner(89)
gamma[observations.length - 1][state] = 0.0;
gamma[observations.length - 1][state] += (alpha[observations.length - 1][state] * beta[observations.length - 1][state]) / denom;
}
/* Update estimates */
for (int originState = 0; originState < stateCount; originState++) {
pi[originState] = gamma[0][originState]; //initial vector
for (int transitionState = 0; transitionState < stateCount; transitionState++) {
numer = 0.0;
denom = 0.0;
for (int t = 0; t < observations.length - 1; ++t) {
numer += xi[t][originState][transitionState];
denom += gamma[t][originState];
}
a[originState][transitionState] = denom != 0 ? numer / denom : a[originState][transitionState]; //transition matrix
}
for (int transitionState = 0; transitionState < observationCount; transitionState++) {
numer = 0.0;
denom = 0.0;
for (int time = 0; time < observations.length; time++) {
if (transitionState == observations[time]) {
numer += gamma[time][originState];
}
denom += gamma[time][originState];
}
b[originState][transitionState] = denom != 0 ? numer / denom : b[originState][transitionState]; //observation matrix
}
}
}
}
/* Calculate the distribution (array) of next step observations */
double getNextObservationProbabilities(double pCurrent) {
double returnProb = new double[observationCount];
for (int currentState = 0; currentState < stateCount; currentState++) {
for (int nextState = 0; nextState < stateCount; nextState++) {
for (int observation = 0; observation < observationCount; observation++) {
returnProb[observation] += pCurrent[nextState] * a[nextState][currentState] * b[currentState][observation];
}
}
}
return returnProb;
}
double getSequenceProbability(int observations) { // Calculate the probability of a given sequence "observations"
double returnProb = 0;
double alpha = new double[observations.length][stateCount]; // FIXME: Explain
for (int i = 0; i < stateCount; i++) {
alpha[0][i] = pi[i] * b[i][observations[0]]; // FIXME: Explain
}
for (int time = 0; ++time < observations.length;) { // FIXME: Explain indexing
for (int i = 0; i < stateCount; i++) {
for (int j = 0; j < stateCount; j++) {
alpha[time][i] += alpha[(time - 1)][j] * a[j][i] * b[i][observations[time]]; // FIXME: Explain
}
}
}
for (int i = 0; i < stateCount; i++) {
returnProb += alpha[(alpha.length - 1)][i];
}
return returnProb;
}
void forward(int observations, double alpha, double scaling) { // FIXME: Explain https://en.wikipedia.org/wiki/Forward_algorithm
scaling[0] = 0;
for (int state = 0; ++state < stateCount;) {
alpha[0][state] = pi[state] * b[state][observations[0]];
scaling[0] += alpha[0][state];
}
scaling[0] = scaling[0] != 0 ? 1.0 / scaling[0] : scaling[0];
for (int state = 0; state < stateCount; state++) {
alpha[0][state] = scaling[0] * alpha[0][state];
}
for (int time = 1; time < observations.length; time++) { // Forward in time
scaling[time] = 0;
for (int state = 0; state < stateCount; state++) {
double sum = 0;
alpha[time][state] = 0;
for (int originState = 0; originState < stateCount; originState++) {
sum += alpha[time - 1][originState] * a[originState][state];
}
alpha[time][state] = sum * b[state][observations[time]];
scaling[time] += sum;
}
scaling[time] = scaling[time] != 0 ? scaling[time] = 1.0 / scaling[time] : scaling[time];
for (int i = 0; i < stateCount; i++) {
alpha[time][i] = scaling[time] * alpha[time][i];
}
}
}
void backward(int observations, double beta, double scaling) { // FIXME: Explain https://en.wikipedia.org/wiki/Forward%E2%80%93backward_algorithm
for (int state = 0; state < stateCount; state++) {
beta[(observations.length - 1)][state] = scaling[observations.length - 1]; // Base-case beta for last observation
}
for (int time = observations.length - 2; time >= 0; time--) { // Backward in time
for (int state = 0; state < stateCount; state++) {
beta[time][state] = 0;
for (int j = 0; j < stateCount; j++) {
beta[time][state] += a[state][j] * b[j][observations[time + 1]] * beta[time + 1][j];
}
beta[time][state] = scaling[time] * beta[time][state];
}
}
}
double getAlpha(int observations) {
double alpha = new double[observations.length][stateCount];
for (int state = 0; state < stateCount; state++) {
alpha[0][state] = pi[state] * b[state][observations[0]]; // Base-case alpha for time = 0
}
for (int time = 1; time < observations.length; time++) {
for (int state = 0; state < stateCount; state++) {
for (int j = 0; j < stateCount; j++) {
alpha[time][state] += alpha[(time - 1)][j] * a[j][state] * b[state][observations[time]]; // FIXME: Explain
}
}
}
return alpha[observations.length - 1];
}
BaumWelch(int stateCount, int observationCount) { // Constructor
this.stateCount = stateCount;
this.observationCount = observationCount;
this.a = new double[stateCount][stateCount];
this.b = new double[stateCount][observationCount];
this.pi = new double[stateCount];
for (int i = 0; i < stateCount; i++) {
for (int j = 0; j < stateCount; j++) {
this.a[i][j] = 1000 + Math.random() * 10000;
}
}
this.a = Util.normalize(a);
for (int i = 0; i < stateCount; i++) {
for (int j = 0; j < observationCount; j++) {
this.b[i][j] = 1000 + Math.random() * 500;
}
}
this.b = Util.normalize(b);
for (int i = 0; i < stateCount; i++) {
this.pi[i] = 1000 + Math.random() * 500;
}
this.pi = Util.normalize(pi);
}
}
java statistics ai machine-learning
$endgroup$
add a comment |
$begingroup$
I am testing my implementation of Baum-Welch algorithm for stock market data. The sequence for the price of a certain stock ends with this:
73 76 76 76 76 75 78 78 76 74 78 78 79 80 83 82 82 81 80 79 78 77 76 77 76 75 77 78 79 80 78 75 75 76 75 76 77 76 75 76 81
Now I try to make my model estimate the last observation:
int obs = { 73, 76, 76, 76, 76, 75, 78, 78, 76, 74, 78, 78, 79, 80, 83, 82, 82,81, 80, 79, 78, 77, 76,77, 76, 75, 77, 78, 79, 80, 78, 75, 75, 76, 75, 76, 77, 76, 75, 76};
BaumWelch model = new BaumWelch(100, 100);
model.run(obs);
double optimumP = 0;
int estimate = -1;
double transitions = Util.normalize(model.getAlpha(obs)); // FIXME: Explain
double nextObservations = Util.normalize(model.getNextObservationProbabilities(transitions));
double maxP = 0;
int tmp = 0;
for (int j = 0; j < nextObservations.length; j++) {
if (nextObservations[j] > maxP) {
maxP = nextObservations[j];
tmp = j;
}
}
if (maxP > optimumP) {
estimate = tmp;
}
System.out.println(estimate);
The output is 75, not very good. I try to estimate the second to last observation with code exactly like above just removing the last (76) and the output becomes 77 which is good enough. I test once again removing the "75" and see if it can estimate it: Output is 75. I remove the last observation again (76) and see what the program estimates. Output is 76.
So it seems close to making predictions in the right direction most of the time. Naturally I will test this much more, but I am looking for bugs regarding correctness of my implementation of Baum-Welch algorithm.
The actual model code is the following.
import java.lang.Math;
class BaumWelch {
int iterationCount = 30; // Arbitrary count of maximum iterations
int stateCount; // Count of hidden states e.g. number of birds
int observationCount; // Count of observations e.g. nu
double a; // Transition matrix
double b; // Observation matrix
double pi; // Initialization vector
void run(int observations) { // FIXME: Explain the Baum-Welch algorithm
double numer, denom, logP, maxP;
double xi = new double[observations.length][stateCount][stateCount];
double alpha = new double[observations.length][stateCount];
double beta = new double[observations.length][stateCount];
double gamma = new double[observations.length][stateCount];
double scaling = new double[observations.length];
maxP = -Double.MAX_VALUE; // FIXME: Explain initialization of logaritmic probability
for(int count=0; count < iterationCount; count++){
logP = 0;
forward(observations, alpha, scaling);
backward(observations, beta, scaling);
for (int time = 0; time < scaling.length; time++) { //FIXME: Explain
logP = scaling[time] != 0 ? logP + Math.log(scaling[time]) : logP;
}
logP = -logP;
if (logP < maxP) break; //improve or break loop
maxP = logP;
for (int time = 0; time < observations.length - 1; time++) { // Σ observations Σ originstates Σ transitionstates
denom = 0; // Calculate a denominator for each transition between possible states
for (int originState = 0; originState < stateCount; originState++) { // ΣΣ
for (int transitionState = 0; transitionState < stateCount; transitionState++) { // ΣΣ
denom += alpha[time][originState] * a[originState][transitionState] * b[transitionState][observations[time + 1]] * beta[(time + 1)][transitionState]; // FIXME: Explain this is count of possible states in eq. 37 http://www.cs.ubc.ca/~murphyk/Bayes/rabiner.pdf
}
}
for (int originState = 0; originState < stateCount; originState++) {
gamma[time][originState] = 0;
for (int transitionState = 0; transitionState < stateCount; ++transitionState) { // divide the count of good states by the count of possible states
xi[time][originState][transitionState] = denom != 0 ? (alpha[time][originState] * a[originState][transitionState] * b[transitionState][observations[time + 1]] * beta[(time + 1)][transitionState]) / denom : 0;
gamma[time][originState] += xi[time][originState][transitionState]; // formula 38 in Rabiner (89)
}
}
}
denom = 0;
for (int state = 0; state < stateCount; state++) { // FIXME: Explain line below. count of possible states
denom += alpha[observations.length - 1][state] * beta[observations.length - 1][state];
}
for (int state = 0; state < stateCount; state++) { // Calculate gamma for the last step. formula 27 in Rabiner(89)
gamma[observations.length - 1][state] = 0.0;
gamma[observations.length - 1][state] += (alpha[observations.length - 1][state] * beta[observations.length - 1][state]) / denom;
}
/* Update estimates */
for (int originState = 0; originState < stateCount; originState++) {
pi[originState] = gamma[0][originState]; //initial vector
for (int transitionState = 0; transitionState < stateCount; transitionState++) {
numer = 0.0;
denom = 0.0;
for (int t = 0; t < observations.length - 1; ++t) {
numer += xi[t][originState][transitionState];
denom += gamma[t][originState];
}
a[originState][transitionState] = denom != 0 ? numer / denom : a[originState][transitionState]; //transition matrix
}
for (int transitionState = 0; transitionState < observationCount; transitionState++) {
numer = 0.0;
denom = 0.0;
for (int time = 0; time < observations.length; time++) {
if (transitionState == observations[time]) {
numer += gamma[time][originState];
}
denom += gamma[time][originState];
}
b[originState][transitionState] = denom != 0 ? numer / denom : b[originState][transitionState]; //observation matrix
}
}
}
}
/* Calculate the distribution (array) of next step observations */
double getNextObservationProbabilities(double pCurrent) {
double returnProb = new double[observationCount];
for (int currentState = 0; currentState < stateCount; currentState++) {
for (int nextState = 0; nextState < stateCount; nextState++) {
for (int observation = 0; observation < observationCount; observation++) {
returnProb[observation] += pCurrent[nextState] * a[nextState][currentState] * b[currentState][observation];
}
}
}
return returnProb;
}
double getSequenceProbability(int observations) { // Calculate the probability of a given sequence "observations"
double returnProb = 0;
double alpha = new double[observations.length][stateCount]; // FIXME: Explain
for (int i = 0; i < stateCount; i++) {
alpha[0][i] = pi[i] * b[i][observations[0]]; // FIXME: Explain
}
for (int time = 0; ++time < observations.length;) { // FIXME: Explain indexing
for (int i = 0; i < stateCount; i++) {
for (int j = 0; j < stateCount; j++) {
alpha[time][i] += alpha[(time - 1)][j] * a[j][i] * b[i][observations[time]]; // FIXME: Explain
}
}
}
for (int i = 0; i < stateCount; i++) {
returnProb += alpha[(alpha.length - 1)][i];
}
return returnProb;
}
void forward(int observations, double alpha, double scaling) { // FIXME: Explain https://en.wikipedia.org/wiki/Forward_algorithm
scaling[0] = 0;
for (int state = 0; ++state < stateCount;) {
alpha[0][state] = pi[state] * b[state][observations[0]];
scaling[0] += alpha[0][state];
}
scaling[0] = scaling[0] != 0 ? 1.0 / scaling[0] : scaling[0];
for (int state = 0; state < stateCount; state++) {
alpha[0][state] = scaling[0] * alpha[0][state];
}
for (int time = 1; time < observations.length; time++) { // Forward in time
scaling[time] = 0;
for (int state = 0; state < stateCount; state++) {
double sum = 0;
alpha[time][state] = 0;
for (int originState = 0; originState < stateCount; originState++) {
sum += alpha[time - 1][originState] * a[originState][state];
}
alpha[time][state] = sum * b[state][observations[time]];
scaling[time] += sum;
}
scaling[time] = scaling[time] != 0 ? scaling[time] = 1.0 / scaling[time] : scaling[time];
for (int i = 0; i < stateCount; i++) {
alpha[time][i] = scaling[time] * alpha[time][i];
}
}
}
void backward(int observations, double beta, double scaling) { // FIXME: Explain https://en.wikipedia.org/wiki/Forward%E2%80%93backward_algorithm
for (int state = 0; state < stateCount; state++) {
beta[(observations.length - 1)][state] = scaling[observations.length - 1]; // Base-case beta for last observation
}
for (int time = observations.length - 2; time >= 0; time--) { // Backward in time
for (int state = 0; state < stateCount; state++) {
beta[time][state] = 0;
for (int j = 0; j < stateCount; j++) {
beta[time][state] += a[state][j] * b[j][observations[time + 1]] * beta[time + 1][j];
}
beta[time][state] = scaling[time] * beta[time][state];
}
}
}
double getAlpha(int observations) {
double alpha = new double[observations.length][stateCount];
for (int state = 0; state < stateCount; state++) {
alpha[0][state] = pi[state] * b[state][observations[0]]; // Base-case alpha for time = 0
}
for (int time = 1; time < observations.length; time++) {
for (int state = 0; state < stateCount; state++) {
for (int j = 0; j < stateCount; j++) {
alpha[time][state] += alpha[(time - 1)][j] * a[j][state] * b[state][observations[time]]; // FIXME: Explain
}
}
}
return alpha[observations.length - 1];
}
BaumWelch(int stateCount, int observationCount) { // Constructor
this.stateCount = stateCount;
this.observationCount = observationCount;
this.a = new double[stateCount][stateCount];
this.b = new double[stateCount][observationCount];
this.pi = new double[stateCount];
for (int i = 0; i < stateCount; i++) {
for (int j = 0; j < stateCount; j++) {
this.a[i][j] = 1000 + Math.random() * 10000;
}
}
this.a = Util.normalize(a);
for (int i = 0; i < stateCount; i++) {
for (int j = 0; j < observationCount; j++) {
this.b[i][j] = 1000 + Math.random() * 500;
}
}
this.b = Util.normalize(b);
for (int i = 0; i < stateCount; i++) {
this.pi[i] = 1000 + Math.random() * 500;
}
this.pi = Util.normalize(pi);
}
}
java statistics ai machine-learning
$endgroup$
I am testing my implementation of Baum-Welch algorithm for stock market data. The sequence for the price of a certain stock ends with this:
73 76 76 76 76 75 78 78 76 74 78 78 79 80 83 82 82 81 80 79 78 77 76 77 76 75 77 78 79 80 78 75 75 76 75 76 77 76 75 76 81
Now I try to make my model estimate the last observation:
int obs = { 73, 76, 76, 76, 76, 75, 78, 78, 76, 74, 78, 78, 79, 80, 83, 82, 82,81, 80, 79, 78, 77, 76,77, 76, 75, 77, 78, 79, 80, 78, 75, 75, 76, 75, 76, 77, 76, 75, 76};
BaumWelch model = new BaumWelch(100, 100);
model.run(obs);
double optimumP = 0;
int estimate = -1;
double transitions = Util.normalize(model.getAlpha(obs)); // FIXME: Explain
double nextObservations = Util.normalize(model.getNextObservationProbabilities(transitions));
double maxP = 0;
int tmp = 0;
for (int j = 0; j < nextObservations.length; j++) {
if (nextObservations[j] > maxP) {
maxP = nextObservations[j];
tmp = j;
}
}
if (maxP > optimumP) {
estimate = tmp;
}
System.out.println(estimate);
The output is 75, not very good. I try to estimate the second to last observation with code exactly like above just removing the last (76) and the output becomes 77 which is good enough. I test once again removing the "75" and see if it can estimate it: Output is 75. I remove the last observation again (76) and see what the program estimates. Output is 76.
So it seems close to making predictions in the right direction most of the time. Naturally I will test this much more, but I am looking for bugs regarding correctness of my implementation of Baum-Welch algorithm.
The actual model code is the following.
import java.lang.Math;
class BaumWelch {
int iterationCount = 30; // Arbitrary count of maximum iterations
int stateCount; // Count of hidden states e.g. number of birds
int observationCount; // Count of observations e.g. nu
double a; // Transition matrix
double b; // Observation matrix
double pi; // Initialization vector
void run(int observations) { // FIXME: Explain the Baum-Welch algorithm
double numer, denom, logP, maxP;
double xi = new double[observations.length][stateCount][stateCount];
double alpha = new double[observations.length][stateCount];
double beta = new double[observations.length][stateCount];
double gamma = new double[observations.length][stateCount];
double scaling = new double[observations.length];
maxP = -Double.MAX_VALUE; // FIXME: Explain initialization of logaritmic probability
for(int count=0; count < iterationCount; count++){
logP = 0;
forward(observations, alpha, scaling);
backward(observations, beta, scaling);
for (int time = 0; time < scaling.length; time++) { //FIXME: Explain
logP = scaling[time] != 0 ? logP + Math.log(scaling[time]) : logP;
}
logP = -logP;
if (logP < maxP) break; //improve or break loop
maxP = logP;
for (int time = 0; time < observations.length - 1; time++) { // Σ observations Σ originstates Σ transitionstates
denom = 0; // Calculate a denominator for each transition between possible states
for (int originState = 0; originState < stateCount; originState++) { // ΣΣ
for (int transitionState = 0; transitionState < stateCount; transitionState++) { // ΣΣ
denom += alpha[time][originState] * a[originState][transitionState] * b[transitionState][observations[time + 1]] * beta[(time + 1)][transitionState]; // FIXME: Explain this is count of possible states in eq. 37 http://www.cs.ubc.ca/~murphyk/Bayes/rabiner.pdf
}
}
for (int originState = 0; originState < stateCount; originState++) {
gamma[time][originState] = 0;
for (int transitionState = 0; transitionState < stateCount; ++transitionState) { // divide the count of good states by the count of possible states
xi[time][originState][transitionState] = denom != 0 ? (alpha[time][originState] * a[originState][transitionState] * b[transitionState][observations[time + 1]] * beta[(time + 1)][transitionState]) / denom : 0;
gamma[time][originState] += xi[time][originState][transitionState]; // formula 38 in Rabiner (89)
}
}
}
denom = 0;
for (int state = 0; state < stateCount; state++) { // FIXME: Explain line below. count of possible states
denom += alpha[observations.length - 1][state] * beta[observations.length - 1][state];
}
for (int state = 0; state < stateCount; state++) { // Calculate gamma for the last step. formula 27 in Rabiner(89)
gamma[observations.length - 1][state] = 0.0;
gamma[observations.length - 1][state] += (alpha[observations.length - 1][state] * beta[observations.length - 1][state]) / denom;
}
/* Update estimates */
for (int originState = 0; originState < stateCount; originState++) {
pi[originState] = gamma[0][originState]; //initial vector
for (int transitionState = 0; transitionState < stateCount; transitionState++) {
numer = 0.0;
denom = 0.0;
for (int t = 0; t < observations.length - 1; ++t) {
numer += xi[t][originState][transitionState];
denom += gamma[t][originState];
}
a[originState][transitionState] = denom != 0 ? numer / denom : a[originState][transitionState]; //transition matrix
}
for (int transitionState = 0; transitionState < observationCount; transitionState++) {
numer = 0.0;
denom = 0.0;
for (int time = 0; time < observations.length; time++) {
if (transitionState == observations[time]) {
numer += gamma[time][originState];
}
denom += gamma[time][originState];
}
b[originState][transitionState] = denom != 0 ? numer / denom : b[originState][transitionState]; //observation matrix
}
}
}
}
/* Calculate the distribution (array) of next step observations */
double getNextObservationProbabilities(double pCurrent) {
double returnProb = new double[observationCount];
for (int currentState = 0; currentState < stateCount; currentState++) {
for (int nextState = 0; nextState < stateCount; nextState++) {
for (int observation = 0; observation < observationCount; observation++) {
returnProb[observation] += pCurrent[nextState] * a[nextState][currentState] * b[currentState][observation];
}
}
}
return returnProb;
}
double getSequenceProbability(int observations) { // Calculate the probability of a given sequence "observations"
double returnProb = 0;
double alpha = new double[observations.length][stateCount]; // FIXME: Explain
for (int i = 0; i < stateCount; i++) {
alpha[0][i] = pi[i] * b[i][observations[0]]; // FIXME: Explain
}
for (int time = 0; ++time < observations.length;) { // FIXME: Explain indexing
for (int i = 0; i < stateCount; i++) {
for (int j = 0; j < stateCount; j++) {
alpha[time][i] += alpha[(time - 1)][j] * a[j][i] * b[i][observations[time]]; // FIXME: Explain
}
}
}
for (int i = 0; i < stateCount; i++) {
returnProb += alpha[(alpha.length - 1)][i];
}
return returnProb;
}
void forward(int observations, double alpha, double scaling) { // FIXME: Explain https://en.wikipedia.org/wiki/Forward_algorithm
scaling[0] = 0;
for (int state = 0; ++state < stateCount;) {
alpha[0][state] = pi[state] * b[state][observations[0]];
scaling[0] += alpha[0][state];
}
scaling[0] = scaling[0] != 0 ? 1.0 / scaling[0] : scaling[0];
for (int state = 0; state < stateCount; state++) {
alpha[0][state] = scaling[0] * alpha[0][state];
}
for (int time = 1; time < observations.length; time++) { // Forward in time
scaling[time] = 0;
for (int state = 0; state < stateCount; state++) {
double sum = 0;
alpha[time][state] = 0;
for (int originState = 0; originState < stateCount; originState++) {
sum += alpha[time - 1][originState] * a[originState][state];
}
alpha[time][state] = sum * b[state][observations[time]];
scaling[time] += sum;
}
scaling[time] = scaling[time] != 0 ? scaling[time] = 1.0 / scaling[time] : scaling[time];
for (int i = 0; i < stateCount; i++) {
alpha[time][i] = scaling[time] * alpha[time][i];
}
}
}
void backward(int observations, double beta, double scaling) { // FIXME: Explain https://en.wikipedia.org/wiki/Forward%E2%80%93backward_algorithm
for (int state = 0; state < stateCount; state++) {
beta[(observations.length - 1)][state] = scaling[observations.length - 1]; // Base-case beta for last observation
}
for (int time = observations.length - 2; time >= 0; time--) { // Backward in time
for (int state = 0; state < stateCount; state++) {
beta[time][state] = 0;
for (int j = 0; j < stateCount; j++) {
beta[time][state] += a[state][j] * b[j][observations[time + 1]] * beta[time + 1][j];
}
beta[time][state] = scaling[time] * beta[time][state];
}
}
}
double getAlpha(int observations) {
double alpha = new double[observations.length][stateCount];
for (int state = 0; state < stateCount; state++) {
alpha[0][state] = pi[state] * b[state][observations[0]]; // Base-case alpha for time = 0
}
for (int time = 1; time < observations.length; time++) {
for (int state = 0; state < stateCount; state++) {
for (int j = 0; j < stateCount; j++) {
alpha[time][state] += alpha[(time - 1)][j] * a[j][state] * b[state][observations[time]]; // FIXME: Explain
}
}
}
return alpha[observations.length - 1];
}
BaumWelch(int stateCount, int observationCount) { // Constructor
this.stateCount = stateCount;
this.observationCount = observationCount;
this.a = new double[stateCount][stateCount];
this.b = new double[stateCount][observationCount];
this.pi = new double[stateCount];
for (int i = 0; i < stateCount; i++) {
for (int j = 0; j < stateCount; j++) {
this.a[i][j] = 1000 + Math.random() * 10000;
}
}
this.a = Util.normalize(a);
for (int i = 0; i < stateCount; i++) {
for (int j = 0; j < observationCount; j++) {
this.b[i][j] = 1000 + Math.random() * 500;
}
}
this.b = Util.normalize(b);
for (int i = 0; i < stateCount; i++) {
this.pi[i] = 1000 + Math.random() * 500;
}
this.pi = Util.normalize(pi);
}
}
java statistics ai machine-learning
java statistics ai machine-learning
asked 9 mins ago
Niklas RosencrantzNiklas Rosencrantz
641826
641826
add a comment |
add a comment |
0
active
oldest
votes
Your Answer
StackExchange.ifUsing("editor", function () {
return StackExchange.using("mathjaxEditing", function () {
StackExchange.MarkdownEditor.creationCallbacks.add(function (editor, postfix) {
StackExchange.mathjaxEditing.prepareWmdForMathJax(editor, postfix, [["\$", "\$"]]);
});
});
}, "mathjax-editing");
StackExchange.ifUsing("editor", function () {
StackExchange.using("externalEditor", function () {
StackExchange.using("snippets", function () {
StackExchange.snippets.init();
});
});
}, "code-snippets");
StackExchange.ready(function() {
var channelOptions = {
tags: "".split(" "),
id: "196"
};
initTagRenderer("".split(" "), "".split(" "), channelOptions);
StackExchange.using("externalEditor", function() {
// Have to fire editor after snippets, if snippets enabled
if (StackExchange.settings.snippets.snippetsEnabled) {
StackExchange.using("snippets", function() {
createEditor();
});
}
else {
createEditor();
}
});
function createEditor() {
StackExchange.prepareEditor({
heartbeatType: 'answer',
autoActivateHeartbeat: false,
convertImagesToLinks: false,
noModals: true,
showLowRepImageUploadWarning: true,
reputationToPostImages: null,
bindNavPrevention: true,
postfix: "",
imageUploader: {
brandingHtml: "Powered by u003ca class="icon-imgur-white" href="https://imgur.com/"u003eu003c/au003e",
contentPolicyHtml: "User contributions licensed under u003ca href="https://creativecommons.org/licenses/by-sa/3.0/"u003ecc by-sa 3.0 with attribution requiredu003c/au003e u003ca href="https://stackoverflow.com/legal/content-policy"u003e(content policy)u003c/au003e",
allowUrls: true
},
onDemand: true,
discardSelector: ".discard-answer"
,immediatelyShowMarkdownHelp:true
});
}
});
Sign up or log in
StackExchange.ready(function () {
StackExchange.helpers.onClickDraftSave('#login-link');
});
Sign up using Google
Sign up using Facebook
Sign up using Email and Password
Post as a guest
Required, but never shown
StackExchange.ready(
function () {
StackExchange.openid.initPostLogin('.new-post-login', 'https%3a%2f%2fcodereview.stackexchange.com%2fquestions%2f212257%2fstock-price-predictions-using-baum-welch-algorithm%23new-answer', 'question_page');
}
);
Post as a guest
Required, but never shown
0
active
oldest
votes
0
active
oldest
votes
active
oldest
votes
active
oldest
votes
Thanks for contributing an answer to Code Review Stack Exchange!
- Please be sure to answer the question. Provide details and share your research!
But avoid …
- Asking for help, clarification, or responding to other answers.
- Making statements based on opinion; back them up with references or personal experience.
Use MathJax to format equations. MathJax reference.
To learn more, see our tips on writing great answers.
Sign up or log in
StackExchange.ready(function () {
StackExchange.helpers.onClickDraftSave('#login-link');
});
Sign up using Google
Sign up using Facebook
Sign up using Email and Password
Post as a guest
Required, but never shown
StackExchange.ready(
function () {
StackExchange.openid.initPostLogin('.new-post-login', 'https%3a%2f%2fcodereview.stackexchange.com%2fquestions%2f212257%2fstock-price-predictions-using-baum-welch-algorithm%23new-answer', 'question_page');
}
);
Post as a guest
Required, but never shown
Sign up or log in
StackExchange.ready(function () {
StackExchange.helpers.onClickDraftSave('#login-link');
});
Sign up using Google
Sign up using Facebook
Sign up using Email and Password
Post as a guest
Required, but never shown
Sign up or log in
StackExchange.ready(function () {
StackExchange.helpers.onClickDraftSave('#login-link');
});
Sign up using Google
Sign up using Facebook
Sign up using Email and Password
Post as a guest
Required, but never shown
Sign up or log in
StackExchange.ready(function () {
StackExchange.helpers.onClickDraftSave('#login-link');
});
Sign up using Google
Sign up using Facebook
Sign up using Email and Password
Sign up using Google
Sign up using Facebook
Sign up using Email and Password
Post as a guest
Required, but never shown
Required, but never shown
Required, but never shown
Required, but never shown
Required, but never shown
Required, but never shown
Required, but never shown
Required, but never shown
Required, but never shown