001/******************************************************************************* 002 * Copyright (c) 2016 Diamond Light Source Ltd. and others. 003 * All rights reserved. This program and the accompanying materials 004 * are made available under the terms of the Eclipse Public License v1.0 005 * which accompanies this distribution, and is available at 006 * http://www.eclipse.org/legal/epl-v10.html 007 * 008 * Contributors: 009 * Diamond Light Source Ltd - initial API and implementation 010 *******************************************************************************/ 011package org.eclipse.january.dataset; 012 013import java.lang.reflect.Array; 014import java.util.ArrayList; 015import java.util.List; 016 017public class ShapeUtils { 018 019 private ShapeUtils() { 020 } 021 022 /** 023 * Calculate total number of items in given shape 024 * @param shape 025 * @return size 026 */ 027 public static long calcLongSize(final int[] shape) { 028 if (shape == null) { // special case of null-shaped 029 return 0; 030 } 031 032 final int rank = shape.length; 033 if (rank == 0) { // special case of zero-rank shape 034 return 1; 035 } 036 037 double dsize = 1.0; 038 for (int i = 0; i < rank; i++) { 039 // make sure the indexes isn't zero or negative 040 if (shape[i] == 0) { 041 return 0; 042 } else if (shape[i] < 0) { 043 throw new IllegalArgumentException(String.format( 044 "The %d-th is %d which is not allowed as it is negative", i, shape[i])); 045 } 046 047 dsize *= shape[i]; 048 } 049 050 // check to see if the size is larger than an integer, i.e. we can't allocate it 051 if (dsize > Long.MAX_VALUE) { 052 throw new IllegalArgumentException("Size of the dataset is too large to allocate"); 053 } 054 return (long) dsize; 055 } 056 057 /** 058 * Calculate total number of items in given shape 059 * @param shape 060 * @return size 061 */ 062 public static int calcSize(final int[] shape) { 063 long lsize = calcLongSize(shape); 064 065 // check to see if the size is larger than an integer, i.e. we can't allocate it 066 if (lsize > Integer.MAX_VALUE) { 067 throw new IllegalArgumentException("Size of the dataset is too large to allocate"); 068 } 069 return (int) lsize; 070 } 071 072 /** 073 * Check if shapes are broadcast compatible 074 * 075 * @param ashape 076 * @param bshape 077 * @return true if they are compatible 078 */ 079 public static boolean areShapesBroadcastCompatible(final int[] ashape, final int[] bshape) { 080 if (ashape == null || bshape == null) { 081 return ashape == bshape; 082 } 083 084 if (ashape.length < bshape.length) { 085 return areShapesBroadcastCompatible(bshape, ashape); 086 } 087 088 for (int a = ashape.length - bshape.length, b = 0; a < ashape.length && b < bshape.length; a++, b++) { 089 if (ashape[a] != bshape[b] && ashape[a] != 1 && bshape[b] != 1) { 090 return false; 091 } 092 } 093 094 return true; 095 } 096 097 /** 098 * Check if shapes are compatible, ignoring extra axes of length 1 099 * 100 * @param ashape 101 * @param bshape 102 * @return true if they are compatible 103 */ 104 public static boolean areShapesCompatible(final int[] ashape, final int[] bshape) { 105 if (ashape == null || bshape == null) { 106 return ashape == bshape; 107 } 108 109 List<Integer> alist = new ArrayList<Integer>(); 110 111 for (int a : ashape) { 112 if (a > 1) alist.add(a); 113 } 114 115 final int imax = alist.size(); 116 int i = 0; 117 for (int b : bshape) { 118 if (b == 1) 119 continue; 120 if (i >= imax || b != alist.get(i++)) 121 return false; 122 } 123 124 return i == imax; 125 } 126 127 /** 128 * Check if shapes are compatible but skip axis 129 * 130 * @param ashape 131 * @param bshape 132 * @param axis 133 * @return true if they are compatible 134 */ 135 public static boolean areShapesCompatible(final int[] ashape, final int[] bshape, final int axis) { 136 if (ashape == null || bshape == null) { 137 return ashape == bshape; 138 } 139 140 if (ashape.length != bshape.length) { 141 return false; 142 } 143 144 final int rank = ashape.length; 145 for (int i = 0; i < rank; i++) { 146 if (i != axis && ashape[i] != bshape[i]) { 147 return false; 148 } 149 } 150 return true; 151 } 152 153 /** 154 * Remove dimensions of 1 in given shape - from both ends only, if true 155 * 156 * @param oshape 157 * @param onlyFromEnds 158 * @return newly squeezed shape (or original if unsqueezed) 159 */ 160 public static int[] squeezeShape(final int[] oshape, boolean onlyFromEnds) { 161 int unitDims = 0; 162 int rank = oshape.length; 163 int start = 0; 164 165 if (onlyFromEnds) { 166 int i = rank - 1; 167 for (; i >= 0; i--) { 168 if (oshape[i] == 1) { 169 unitDims++; 170 } else { 171 break; 172 } 173 } 174 for (int j = 0; j <= i; j++) { 175 if (oshape[j] == 1) { 176 unitDims++; 177 } else { 178 start = j; 179 break; 180 } 181 } 182 } else { 183 for (int i = 0; i < rank; i++) { 184 if (oshape[i] == 1) { 185 unitDims++; 186 } 187 } 188 } 189 190 if (unitDims == 0) { 191 return oshape; 192 } 193 194 int[] newDims = new int[rank - unitDims]; 195 if (unitDims == rank) 196 return newDims; // zero-rank dataset 197 198 if (onlyFromEnds) { 199 rank = newDims.length; 200 for (int i = 0; i < rank; i++) { 201 newDims[i] = oshape[i+start]; 202 } 203 } else { 204 int j = 0; 205 for (int i = 0; i < rank; i++) { 206 if (oshape[i] > 1) { 207 newDims[j++] = oshape[i]; 208 if (j >= newDims.length) 209 break; 210 } 211 } 212 } 213 214 return newDims; 215 } 216 217 /** 218 * Remove dimension of 1 in given shape 219 * 220 * @param oshape 221 * @param axis 222 * @return newly squeezed shape 223 */ 224 public static int[] squeezeShape(final int[] oshape, int axis) { 225 if (oshape == null) { 226 return null; 227 } 228 229 final int rank = oshape.length; 230 if (rank == 0) { 231 return new int[0]; 232 } 233 if (axis < 0) { 234 axis += rank; 235 } 236 if (axis < 0 || axis >= rank) { 237 throw new IllegalArgumentException("Axis argument is outside allowed range"); 238 } 239 int[] nshape = new int[rank-1]; 240 for (int i = 0; i < axis; i++) { 241 nshape[i] = oshape[i]; 242 } 243 for (int i = axis+1; i < rank; i++) { 244 nshape[i-1] = oshape[i]; 245 } 246 return nshape; 247 } 248 249 /** 250 * Get shape from object (array or list supported) 251 * @param obj 252 * @return shape can be null if obj is null 253 */ 254 public static int[] getShapeFromObject(final Object obj) { 255 if (obj == null) { 256 return null; 257 } 258 259 ArrayList<Integer> lshape = new ArrayList<Integer>(); 260 getShapeFromObj(lshape, obj, 0); 261 262 final int rank = lshape.size(); 263 final int[] shape = new int[rank]; 264 for (int i = 0; i < rank; i++) { 265 shape[i] = lshape.get(i); 266 } 267 268 return shape; 269 } 270 271 /** 272 * Get shape from object 273 * @param ldims 274 * @param obj 275 * @param depth 276 * @return true if there is a possibility of differing lengths 277 */ 278 private static boolean getShapeFromObj(final ArrayList<Integer> ldims, Object obj, int depth) { 279 if (obj == null) 280 return true; 281 282 if (obj instanceof List<?>) { 283 List<?> jl = (List<?>) obj; 284 int l = jl.size(); 285 updateShape(ldims, depth, l); 286 for (int i = 0; i < l; i++) { 287 Object lo = jl.get(i); 288 if (!getShapeFromObj(ldims, lo, depth + 1)) { 289 break; 290 } 291 } 292 return true; 293 } 294 Class<? extends Object> ca = obj.getClass().getComponentType(); 295 if (ca != null) { 296 final int l = Array.getLength(obj); 297 updateShape(ldims, depth, l); 298 if (DTypeUtils.isClassSupportedAsElement(ca)) { 299 return true; 300 } 301 for (int i = 0; i < l; i++) { 302 Object lo = Array.get(obj, i); 303 if (!getShapeFromObj(ldims, lo, depth + 1)) { 304 break; 305 } 306 } 307 return true; 308 } else if (obj instanceof IDataset) { 309 int[] s = ((IDataset) obj).getShape(); 310 for (int i = 0; i < s.length; i++) { 311 updateShape(ldims, depth++, s[i]); 312 } 313 return true; 314 } else { 315 return false; // not an array of any type 316 } 317 } 318 319 private static void updateShape(final ArrayList<Integer> ldims, final int depth, final int l) { 320 if (depth >= ldims.size()) { 321 ldims.add(l); 322 } else if (l > ldims.get(depth)) { 323 ldims.set(depth, l); 324 } 325 } 326 327 /** 328 * Get n-D position from given index 329 * @param n index 330 * @param shape 331 * @return n-D position 332 */ 333 public static int[] getNDPositionFromShape(int n, int[] shape) { 334 if (shape == null) { 335 return null; 336 } 337 338 int rank = shape.length; 339 if (rank == 0) { 340 return new int[0]; 341 } 342 343 if (rank == 1) { 344 return new int[] { n }; 345 } 346 347 int[] output = new int[rank]; 348 for (rank--; rank > 0; rank--) { 349 output[rank] = n % shape[rank]; 350 n /= shape[rank]; 351 } 352 output[0] = n; 353 354 return output; 355 } 356 357 /** 358 * Get flattened view index of given position 359 * @param shape 360 * @param pos 361 * the integer array specifying the n-D position 362 * @return the index on the flattened dataset 363 */ 364 public static int getFlat1DIndex(final int[] shape, final int[] pos) { 365 final int imax = pos.length; 366 if (imax == 0) { 367 return 0; 368 } 369 370 return AbstractDataset.get1DIndexFromShape(shape, pos); 371 } 372 373 /** 374 * This function takes a dataset and checks its shape against another dataset. If they are both of the same size, 375 * then this returns with no error, if there is a problem, then an error is thrown. 376 * 377 * @param g 378 * The first dataset to be compared 379 * @param h 380 * The second dataset to be compared 381 * @throws IllegalArgumentException 382 * This will be thrown if there is a problem with the compatibility 383 */ 384 public static void checkCompatibility(final ILazyDataset g, final ILazyDataset h) throws IllegalArgumentException { 385 if (!areShapesCompatible(g.getShape(), h.getShape())) { 386 throw new IllegalArgumentException("Shapes do not match"); 387 } 388 } 389 390 /** 391 * Check that axis is in range [-rank,rank) 392 * 393 * @param rank 394 * @param axis 395 * @return sanitized axis in range [0, rank) 396 * @since 2.1 397 */ 398 public static int checkAxis(int rank, int axis) { 399 if (axis < 0) { 400 axis += rank; 401 } 402 403 if (axis < 0 || axis >= rank) { 404 throw new IllegalArgumentException("Axis " + axis + " given is out of range [0, " + rank + ")"); 405 } 406 return axis; 407 } 408}