001/*- 002 * Copyright 2017 Diamond Light Source Ltd. 003 * 004 * All rights reserved. This program and the accompanying materials 005 * are made available under the terms of the Eclipse Public License v1.0 006 * which accompanies this distribution, and is available at 007 * http://www.eclipse.org/legal/epl-v10.html 008 */ 009 010package org.eclipse.january.dataset; 011 012import java.util.Arrays; 013import java.util.List; 014 015/** 016 * Class to run over a pair of datasets in parallel with NumPy broadcasting to promote shapes 017 * which have lower rank and outputs to a third dataset 018 * @since 2.1 019 */ 020public class BooleanBroadcastIterator extends BooleanIteratorBase { 021 private int[] cShape; 022 private int[] cStride; 023 024 private final int[] cDelta; 025 private final int cStep; 026 private int cMax; 027 private int cStart; 028 029 /** 030 * Construct a boolean iterator that stops at every position in the choice dataset where its value matches 031 * the given boolean 032 * @param v boolean value 033 * @param a primary dataset 034 * @param c choice dataset 035 * @param o output dataset, can be null 036 * @param createIfNull if true create the output dataset if that is null 037 */ 038 public BooleanBroadcastIterator(boolean v, Dataset a, Dataset c, Dataset o, boolean createIfNull) { 039 super(v, a, c, o); 040 List<int[]> fullShapes = BroadcastUtils.broadcastShapes(a.getShapeRef(), c.getShapeRef(), o == null ? null : o.getShapeRef()); 041 042 maxShape = fullShapes.remove(0); 043 044 oStride = null; 045 if (o != null && !Arrays.equals(maxShape, o.getShapeRef())) { 046 throw new IllegalArgumentException("Output does not match broadcasted shape"); 047 } 048 aShape = fullShapes.remove(0); 049 cShape = fullShapes.remove(0); 050 051 int rank = maxShape.length; 052 endrank = rank - 1; 053 054 aDataset = a.reshape(aShape); 055 cDataset = c.reshape(cShape); 056 aStride = BroadcastUtils.createBroadcastStrides(aDataset, maxShape); 057 cStride = BroadcastUtils.createBroadcastStrides(cDataset, maxShape); 058 if (outputA) { 059 oStride = aStride; 060 oDelta = null; 061 oStep = 0; 062 } else if (o != null) { 063 oStride = BroadcastUtils.createBroadcastStrides(o, maxShape); 064 oDelta = new int[rank]; 065 oStep = o.getElementsPerItem(); 066 } else if (createIfNull) { 067 oDataset = BroadcastUtils.createDataset(a, c, maxShape); 068 oStride = BroadcastUtils.createBroadcastStrides(oDataset, maxShape); 069 oDelta = new int[rank]; 070 oStep = oDataset.getElementsPerItem(); 071 } else { 072 oDelta = null; 073 oStep = 0; 074 } 075 076 pos = new int[rank]; 077 aDelta = new int[rank]; 078 cDelta = new int[rank]; 079 cStep = cDataset.getElementsPerItem(); 080 for (int j = endrank; j >= 0; j--) { 081 aDelta[j] = aStride[j] * aShape[j]; 082 cDelta[j] = cStride[j] * cShape[j]; 083 if (oDelta != null) { 084 oDelta[j] = oStride[j] * maxShape[j]; 085 } 086 } 087 aStart = aDataset.getOffset(); 088 cStart = cDataset.getOffset(); 089 aMax = endrank < 0 ? aStep + aStart: Integer.MIN_VALUE; 090 cMax = endrank < 0 ? cStep + cStart: Integer.MIN_VALUE; 091 oStart = oDelta == null ? 0 : oDataset.getOffset(); 092 reset(); 093 } 094 095 @Override 096 public boolean hasNext() { 097 do { 098 int j = endrank; 099 for (; j >= 0; j--) { 100 pos[j]++; 101 index += aStride[j]; 102 cIndex += cStride[j]; 103 if (oDelta != null) { 104 oIndex += oStride[j]; 105 } 106 if (pos[j] >= maxShape[j]) { 107 pos[j] = 0; 108 index -= aDelta[j]; // reset these dimensions 109 cIndex -= cDelta[j]; 110 if (oDelta != null) { 111 oIndex -= oDelta[j]; 112 } 113 } else { 114 break; 115 } 116 } 117 if (j == -1) { 118 if (endrank >= 0) { 119 return false; 120 } 121 index += aStep; 122 cIndex += cStep; 123 if (oDelta != null) { 124 oIndex += oStep; 125 } 126 } 127 if (outputA) { 128 oIndex = index; 129 } 130 131 if (index == aMax || cIndex == cMax) { 132 return false; 133 } 134 } while (cDataset.getElementBooleanAbs(cIndex) != value); 135 136 return true; 137 } 138 139 /** 140 * @return shape of first broadcasted dataset 141 */ 142 public int[] getFirstShape() { 143 return aShape; 144 } 145 146 /** 147 * @return shape of second broadcasted dataset 148 */ 149 public int[] getMaskShape() { 150 return cShape; 151 } 152 153 @Override 154 public void reset() { 155 for (int i = 0; i <= endrank; i++) { 156 pos[i] = 0; 157 } 158 159 if (endrank >= 0) { 160 pos[endrank] = -1; 161 index = aStart - aStride[endrank]; 162 cIndex = cStart - cStride[endrank]; 163 oIndex = oStart - (oStride == null ? 0 : oStride[endrank]); 164 } else { 165 index = aStart - aStep; 166 cIndex = cStart - cStep; 167 oIndex = oStart - oStep; 168 } 169 } 170}