根据权重进行排序,结果为排序后的索引。
限制:
1.入参个数必须大于1
2.所有参数必须大于0(小于等于0的权重无意义,sortByWeightAndRandom方法会将小于等于0的放最后进行随机排序)
代码
import java.util.ArrayList; import java.util.Random; /** * 随机排序或者按权重排序,返回排序后的索引. * * @author Laeni */ public final class AliasMethodKit { private static final Random RANDOM = new Random(); /** * 根据权重进行排序. * * @param weights 需要排序的元素权重. * @return 返回排序后元素的索引 */ public static int[] sortByWeight(int[] weights) { if (weights.length == 0) { return new int[0]; } // 将权重值依次放到 0 开始的数轴上: [数轴起始位置, 数轴结束位置, 原始值所在的索引位置], 起始位置与结束位置的差即为原始权重值 // 如 - [3, 5, 2] => [[0, 3, 0], [3, 8, 1], [8, 10, 2]] int[][] xx = new int[weights.length][]; for (int i = 0; i < weights.length; i++) { if (weights[i] <= 0) { throw new IllegalArgumentException("权重值必须大于0,如果需要排序包含小于等于0的权重列表,请使用 sortByWeightAndRandom()"); } if (i == 0) { xx[0] = new int[]{0, weights[i], i}; } else { final int start = xx[i-1][1]; xx[i] = new int[]{start, start + weights[i], i}; } } final int[] ok = new int[weights.length]; // 每次生成一个索引 for (int i = 0; i < weights.length; i++) { final int random = RANDOM.nextInt(xx[xx.length - 1 - i][1]); for (int j = 0; j < xx.length - i; j++) { // 如果某一段被选中之后,要将其从数轴中去除,并且其后的数据要依次往前移动,使得数轴不间断 if (random >= xx[j][0] && random < xx[j][1]) { // 记录选中段的原始索引 ok[i] = xx[j][2]; // 将后面的数据往前移动 for (int k = j; k < xx.length - i - 1; k++) { xx[k][1] = xx[k + 1][1] - (xx[k + 1][0] - xx[k][0]); xx[k][2] = xx[k + 1][2]; xx[k + 1][0] = xx[k][1]; } break; } } } return ok; } /** * 随机进行排序. * * @param weights 需要排序的元素索引. */ public static int[] sortByRandom(int[] weights) { // 提取索引 int[] ins = new int[weights.length]; for (int i = 0; i < weights.length; i++) { ins[i] = i; } // 随机排序 randomArray(ins); return ins; } /** * 根据权重进行排序,返回权重对应的索引. * 对于权重小于等于0的将放到最后进行随机排序 * * @param weights 原始数据对应的权重 * @return 排序后的索引 */ public static int[] sortByWeightAndRandom(int[] weights) { if (weights.length == 0) { return new int[0]; } /// region 分离大于0和非大于0的元素 // 大于0的元素(子元素为长度为2的定长元素,第一个元素是原始权重值,第二个元素为该元素所在的原始索引) final ArrayList<Integer[]> moreWeights = new ArrayList<>(weights.length); // 小于等于0的元素(子元素为长度为2的定长元素,第一个元素是原始权重值,第二个元素为该元素所在的原始索引) final ArrayList<Integer[]> lessWeights = new ArrayList<>(weights.length); for (int i = 0; i < weights.length; i++) { int weight = weights[i]; final Integer[] item = new Integer[]{weight, i}; if (weight > 0) { moreWeights.add(item); } else { lessWeights.add(item); } } /// endregion // 最终排序结果 int[] ok = new int[weights.length]; // 排序大于0的元素 if (moreWeights.size() > 0) { final int[] int1 = new int[moreWeights.size()]; for (int i = 0; i < moreWeights.size(); i++) { int1[i] = moreWeights.get(i)[0]; } final int[] sort = sortByWeight(int1); for (int i = 0; i < sort.length; i++) { ok[i] = moreWeights.get(sort[i])[1]; } } // 排序小于等于0的元素 if (lessWeights.size() > 0) { final int[] int2 = new int[lessWeights.size()]; for (int i = 0; i < lessWeights.size(); i++) { int2[i] = lessWeights.get(i)[0]; } final int[] sort = sortByRandom(int2); for (int i = 0; i < sort.length; i++) { ok[i + moreWeights.size()] = lessWeights.get(sort[i])[1]; } } return ok; } /** * 随机排序. * * @param vs 需要进行排序的数组. */ private static void randomArray(int[] vs) { for (int i = 0; i < vs.length; i++) { // 这里已经排过的不能再排,否则概率不正确 int p = RANDOM.nextInt(vs.length - i); int tmp = vs[i]; vs[i] = vs[p + i]; vs[p + i] = tmp; } } }