001/*
002 *  This file is part of the Jikes RVM project (http://jikesrvm.org).
003 *
004 *  This file is licensed to You under the Eclipse Public License (EPL);
005 *  You may not use this file except in compliance with the License. You
006 *  may obtain a copy of the License at
007 *
008 *      http://www.opensource.org/licenses/eclipse-1.0.php
009 *
010 *  See the COPYRIGHT.txt file distributed with this work for information
011 *  regarding copyright ownership.
012 */
013package org.jikesrvm.compilers.opt.ir;
014
015import java.util.Enumeration;
016
017import org.jikesrvm.VM;
018import org.jikesrvm.compilers.opt.OptimizingCompilerException;
019import org.jikesrvm.compilers.opt.ir.operand.BranchProfileOperand;
020
021/**
022 * Used to iterate over the branch targets (including the fall through edge)
023 * and associated probabilites of a basic block.
024 * Takes into account the ordering of branch instructions when
025 * computing the edge weights such that the total target weight will always
026 * be equal to 1.0 (flow in == flow out).
027 */
028public final class WeightedBranchTargets {
029  private BasicBlock[] targets;
030  private float[] weights;
031  private int cur;
032  private int max;
033
034  public void reset() {
035    cur = 0;
036  }
037
038  public boolean hasMoreElements() {
039    return cur < max;
040  }
041
042  public void advance() {
043    cur++;
044  }
045
046  public BasicBlock curBlock() {
047    return targets[cur];
048  }
049
050  public float curWeight() {
051    return weights[cur];
052  }
053
054  public WeightedBranchTargets(BasicBlock bb) {
055    targets = new BasicBlock[3];
056    weights = new float[3];
057    cur = 0;
058    max = 0;
059
060    float prob = 1f;
061    for (Enumeration<Instruction> ie = bb.enumerateBranchInstructions(); ie.hasMoreElements();) {
062      Instruction s = ie.nextElement();
063      if (IfCmp.conforms(s)) {
064        BasicBlock target = IfCmp.getTarget(s).target.getBasicBlock();
065        BranchProfileOperand prof = IfCmp.getBranchProfile(s);
066        float taken = prob * prof.takenProbability;
067        prob = prob * (1f - prof.takenProbability);
068        addEdge(target, taken);
069      } else if (Goto.conforms(s)) {
070        BasicBlock target = Goto.getTarget(s).target.getBasicBlock();
071        addEdge(target, prob);
072      } else if (InlineGuard.conforms(s)) {
073        BasicBlock target = InlineGuard.getTarget(s).target.getBasicBlock();
074        BranchProfileOperand prof = InlineGuard.getBranchProfile(s);
075        float taken = prob * prof.takenProbability;
076        prob = prob * (1f - prof.takenProbability);
077        addEdge(target, taken);
078      } else if (IfCmp2.conforms(s)) {
079        BasicBlock target = IfCmp2.getTarget1(s).target.getBasicBlock();
080        BranchProfileOperand prof = IfCmp2.getBranchProfile1(s);
081        float taken = prob * prof.takenProbability;
082        prob = prob * (1f - prof.takenProbability);
083        addEdge(target, taken);
084        target = IfCmp2.getTarget2(s).target.getBasicBlock();
085        prof = IfCmp2.getBranchProfile2(s);
086        taken = prob * prof.takenProbability;
087        prob = prob * (1f - prof.takenProbability);
088        addEdge(target, taken);
089      } else if (TableSwitch.conforms(s)) {
090        int lowLimit = TableSwitch.getLow(s).value;
091        int highLimit = TableSwitch.getHigh(s).value;
092        int number = highLimit - lowLimit + 1;
093        float total = 0f;
094        for (int i = 0; i < number; i++) {
095          BasicBlock target = TableSwitch.getTarget(s, i).target.getBasicBlock();
096          BranchProfileOperand prof = TableSwitch.getBranchProfile(s, i);
097          float taken = prob * prof.takenProbability;
098          total += prof.takenProbability;
099          addEdge(target, taken);
100        }
101        BasicBlock target = TableSwitch.getDefault(s).target.getBasicBlock();
102        BranchProfileOperand prof = TableSwitch.getDefaultBranchProfile(s);
103        float taken = prob * prof.takenProbability;
104        total += prof.takenProbability;
105        if (VM.VerifyAssertions && !epsilon(total, 1f)) {
106          VM.sysFail("Total outflow (" + total + ") does not sum to 1 for: " + s);
107        }
108        addEdge(target, taken);
109      } else if (LowTableSwitch.conforms(s)) {
110        int number = LowTableSwitch.getNumberOfTargets(s);
111        float total = 0f;
112        for (int i = 0; i < number; i++) {
113          BasicBlock target = LowTableSwitch.getTarget(s, i).target.getBasicBlock();
114          BranchProfileOperand prof = LowTableSwitch.getBranchProfile(s, i);
115          float taken = prob * prof.takenProbability;
116          total += prof.takenProbability;
117          addEdge(target, taken);
118        }
119        if (VM.VerifyAssertions && !epsilon(total, 1f)) {
120          VM.sysFail("Total outflow (" + total + ") does not sum to 1 for: " + s);
121        }
122      } else if (LookupSwitch.conforms(s)) {
123        int number = LookupSwitch.getNumberOfTargets(s);
124        float total = 0f;
125        for (int i = 0; i < number; i++) {
126          BasicBlock target = LookupSwitch.getTarget(s, i).target.getBasicBlock();
127          BranchProfileOperand prof = LookupSwitch.getBranchProfile(s, i);
128          float taken = prob * prof.takenProbability;
129          total += prof.takenProbability;
130          addEdge(target, taken);
131        }
132        BasicBlock target = LookupSwitch.getDefault(s).target.getBasicBlock();
133        BranchProfileOperand prof = LookupSwitch.getDefaultBranchProfile(s);
134        float taken = prob * prof.takenProbability;
135        total += prof.takenProbability;
136        if (VM.VerifyAssertions && !epsilon(total, 1f)) {
137          VM.sysFail("Total outflow (" + total + ") does not sum to 1 for: " + s);
138        }
139        addEdge(target, taken);
140      } else {
141        throw new OptimizingCompilerException("TODO " + s + "\n");
142      }
143    }
144    BasicBlock ft = bb.getFallThroughBlock();
145    if (ft != null) addEdge(ft, prob);
146  }
147
148  private void addEdge(BasicBlock target, float weight) {
149    if (max == targets.length) {
150      BasicBlock[] tmp = new BasicBlock[targets.length << 1];
151      for (int i = 0; i < targets.length; i++) {
152        tmp[i] = targets[i];
153      }
154      targets = tmp;
155      float[] tmp2 = new float[weights.length << 1];
156      for (int i = 0; i < weights.length; i++) {
157        tmp2[i] = weights[i];
158      }
159      weights = tmp2;
160    }
161    targets[max] = target;
162    weights[max] = weight;
163    max++;
164  }
165
166  private boolean epsilon(float a, float b) {
167    return Math.abs(a - b) < 0.1;
168  }
169}