001/******************************************************************************* 002 * Copyright (c) 2016 Diamond Light Source Ltd. and others. 003 * All rights reserved. This program and the accompanying materials 004 * are made available under the terms of the Eclipse Public License v1.0 005 * which accompanies this distribution, and is available at 006 * http://www.eclipse.org/legal/epl-v10.html 007 * 008 * Contributors: 009 * Diamond Light Source Ltd - initial API and implementation 010 *******************************************************************************/ 011package org.eclipse.january.dataset; 012 013import java.lang.reflect.Array; 014import java.util.ArrayList; 015import java.util.Arrays; 016import java.util.Collection; 017import java.util.List; 018import java.util.SortedSet; 019import java.util.TreeSet; 020 021public class ShapeUtils { 022 023 private ShapeUtils() { 024 } 025 026 /** 027 * Calculate total number of items in given shape 028 * @param shape dataset shape 029 * @return size 030 */ 031 public static long calcLongSize(final int[] shape) { 032 if (shape == null) { // special case of null-shaped 033 return 0; 034 } 035 036 final int rank = shape.length; 037 if (rank == 0) { // special case of zero-rank shape 038 return 1; 039 } 040 041 double dsize = 1.0; 042 for (int i = 0; i < rank; i++) { 043 // make sure the indexes isn't zero or negative 044 if (shape[i] == 0) { 045 return 0; 046 } else if (shape[i] < 0) { 047 throw new IllegalArgumentException(String.format( 048 "The %d-th is %d which is not allowed as it is negative", i, shape[i])); 049 } 050 051 dsize *= shape[i]; 052 } 053 054 // check to see if the size is larger than an integer, i.e. we can't allocate it 055 if (dsize > Long.MAX_VALUE) { 056 throw new IllegalArgumentException("Size of the dataset is too large to allocate"); 057 } 058 return (long) dsize; 059 } 060 061 /** 062 * Calculate total number of items in given shape 063 * @param shape dataset shape 064 * @return size 065 */ 066 public static int calcSize(final int[] shape) { 067 long lsize = calcLongSize(shape); 068 069 // check to see if the size is larger than an integer, i.e. we can't allocate it 070 if (lsize > Integer.MAX_VALUE) { 071 throw new IllegalArgumentException("Size of the dataset is too large to allocate"); 072 } 073 return (int) lsize; 074 } 075 076 /** 077 * Check if size is zero 078 * @param shape dataset shape 079 * @return true if shape is null or any dimension in shape is zero 080 * @since 2.3 081 */ 082 public static boolean isZeroSize(final int[] shape) { 083 if (shape == null) { // special case of null-shaped 084 return true; 085 } 086 087 final int rank = shape.length; 088 if (rank == 0) { // special case of zero-rank shape 089 return false; 090 } 091 092 for (int i = 0; i < rank; i++) { 093 if (shape[i] == 0) { 094 return true; 095 } 096 } 097 098 return false; 099 } 100 101 /** 102 * Check if shapes are broadcast compatible 103 * 104 * @param ashape first shape 105 * @param bshape second shape 106 * @return true if they are compatible 107 */ 108 public static boolean areShapesBroadcastCompatible(final int[] ashape, final int[] bshape) { 109 if (ashape == null || bshape == null) { 110 return ashape == bshape; 111 } 112 113 if (ashape.length < bshape.length) { 114 return areShapesBroadcastCompatible(bshape, ashape); 115 } 116 117 for (int a = ashape.length - bshape.length, b = 0; a < ashape.length && b < bshape.length; a++, b++) { 118 if (ashape[a] != bshape[b] && ashape[a] != 1 && bshape[b] != 1) { 119 return false; 120 } 121 } 122 123 return true; 124 } 125 126 /** 127 * Check if shapes are compatible, ignoring extra axes of length 1 128 * 129 * @param ashape first shape 130 * @param bshape second shape 131 * @return true if they are compatible 132 */ 133 public static boolean areShapesCompatible(final int[] ashape, final int[] bshape) { 134 if (ashape == null || bshape == null) { 135 return ashape == bshape; 136 } 137 138 List<Integer> alist = new ArrayList<Integer>(); 139 140 for (int a : ashape) { 141 if (a > 1) alist.add(a); 142 } 143 144 final int imax = alist.size(); 145 int i = 0; 146 for (int b : bshape) { 147 if (b == 1) 148 continue; 149 if (i >= imax || b != alist.get(i++)) 150 return false; 151 } 152 153 return i == imax; 154 } 155 156 /** 157 * Check if shapes are compatible but skip axis 158 * 159 * @param ashape first shape 160 * @param bshape second shape 161 * @param axis to skip 162 * @return true if they are compatible 163 */ 164 public static boolean areShapesCompatible(final int[] ashape, final int[] bshape, final int axis) { 165 if (ashape == null || bshape == null) { 166 return ashape == bshape; 167 } 168 169 if (ashape.length != bshape.length) { 170 return false; 171 } 172 173 final int rank = ashape.length; 174 for (int i = 0; i < rank; i++) { 175 if (i != axis && ashape[i] != bshape[i]) { 176 return false; 177 } 178 } 179 return true; 180 } 181 182 /** 183 * Remove dimensions of 1 in given shape - from both ends only, if true 184 * 185 * @param shape dataset shape 186 * @param onlyFromEnds if true, trim ends 187 * @return newly squeezed shape (or original if unsqueezed) 188 */ 189 public static int[] squeezeShape(final int[] shape, boolean onlyFromEnds) { 190 int unitDims = 0; 191 int rank = shape.length; 192 int start = 0; 193 194 if (onlyFromEnds) { 195 int i = rank - 1; 196 for (; i >= 0; i--) { 197 if (shape[i] == 1) { 198 unitDims++; 199 } else { 200 break; 201 } 202 } 203 for (int j = 0; j <= i; j++) { 204 if (shape[j] == 1) { 205 unitDims++; 206 } else { 207 start = j; 208 break; 209 } 210 } 211 } else { 212 for (int i = 0; i < rank; i++) { 213 if (shape[i] == 1) { 214 unitDims++; 215 } 216 } 217 } 218 219 if (unitDims == 0) { 220 return shape; 221 } 222 223 int[] newDims = new int[rank - unitDims]; 224 if (unitDims == rank) 225 return newDims; // zero-rank dataset 226 227 if (onlyFromEnds) { 228 rank = newDims.length; 229 for (int i = 0; i < rank; i++) { 230 newDims[i] = shape[i+start]; 231 } 232 } else { 233 int j = 0; 234 for (int i = 0; i < rank; i++) { 235 if (shape[i] > 1) { 236 newDims[j++] = shape[i]; 237 if (j >= newDims.length) 238 break; 239 } 240 } 241 } 242 243 return newDims; 244 } 245 246 /** 247 * Remove dimension of 1 in given shape 248 * 249 * @param shape dataset shape 250 * @param axis to remove 251 * @return newly squeezed shape 252 */ 253 public static int[] squeezeShape(final int[] shape, int axis) { 254 if (shape == null) { 255 return null; 256 } 257 258 final int rank = shape.length; 259 if (rank == 0) { 260 return new int[0]; 261 } 262 if (axis < 0) { 263 axis += rank; 264 } 265 if (axis < 0 || axis >= rank) { 266 throw new IllegalArgumentException("Axis argument is outside allowed range"); 267 } 268 int[] nshape = new int[rank-1]; 269 for (int i = 0; i < axis; i++) { 270 nshape[i] = shape[i]; 271 } 272 for (int i = axis+1; i < rank; i++) { 273 nshape[i-1] = shape[i]; 274 } 275 return nshape; 276 } 277 278 /** 279 * Get shape from object (array or list supported) 280 * @param obj object 281 * @return shape can be null if obj is null 282 */ 283 public static int[] getShapeFromObject(final Object obj) { 284 if (obj == null) { 285 return null; 286 } 287 288 ArrayList<Integer> lshape = new ArrayList<Integer>(); 289 getShapeFromObj(lshape, obj, 0); 290 291 final int rank = lshape.size(); 292 final int[] shape = new int[rank]; 293 for (int i = 0; i < rank; i++) { 294 shape[i] = lshape.get(i); 295 } 296 297 return shape; 298 } 299 300 /** 301 * Get shape from object 302 * @param ldims 303 * @param obj 304 * @param depth 305 * @return true if there is a possibility of differing lengths 306 */ 307 private static boolean getShapeFromObj(final ArrayList<Integer> ldims, Object obj, int depth) { 308 if (obj == null) 309 return true; 310 311 if (obj instanceof List<?>) { 312 List<?> jl = (List<?>) obj; 313 int l = jl.size(); 314 updateShape(ldims, depth, l); 315 for (int i = 0; i < l; i++) { 316 Object lo = jl.get(i); 317 if (!getShapeFromObj(ldims, lo, depth + 1)) { 318 break; 319 } 320 } 321 return true; 322 } 323 Class<? extends Object> ca = obj.getClass().getComponentType(); 324 if (ca != null) { 325 final int l = Array.getLength(obj); 326 updateShape(ldims, depth, l); 327 if (InterfaceUtils.isElementSupported(ca)) { 328 return true; 329 } 330 for (int i = 0; i < l; i++) { 331 Object lo = Array.get(obj, i); 332 if (!getShapeFromObj(ldims, lo, depth + 1)) { 333 break; 334 } 335 } 336 return true; 337 } else if (obj instanceof IDataset) { 338 int[] s = ((IDataset) obj).getShape(); 339 for (int i = 0; i < s.length; i++) { 340 updateShape(ldims, depth++, s[i]); 341 } 342 return true; 343 } else { 344 return false; // not an array of any type 345 } 346 } 347 348 private static void updateShape(final ArrayList<Integer> ldims, final int depth, final int l) { 349 if (depth >= ldims.size()) { 350 ldims.add(l); 351 } else if (l > ldims.get(depth)) { 352 ldims.set(depth, l); 353 } 354 } 355 356 /** 357 * Get n-D position from given index 358 * @param n absolute index 359 * @param shape dataset shape 360 * @return n-D position 361 */ 362 public static int[] getNDPositionFromShape(int n, int[] shape) { 363 if (shape == null) { 364 return null; 365 } 366 367 int rank = shape.length; 368 if (rank == 0) { 369 return new int[0]; 370 } 371 372 if (rank == 1) { 373 return new int[] { n }; 374 } 375 376 int[] output = new int[rank]; 377 for (rank--; rank > 0; rank--) { 378 output[rank] = n % shape[rank]; 379 n /= shape[rank]; 380 } 381 output[0] = n; 382 383 return output; 384 } 385 386 /** 387 * Get flattened view index of given position 388 * @param shape dataset shape 389 * @param pos 390 * the integer array specifying the n-D position 391 * @return the index on the flattened dataset 392 */ 393 public static int getFlat1DIndex(final int[] shape, final int[] pos) { 394 final int imax = pos.length; 395 if (imax == 0) { 396 return 0; 397 } 398 399 return AbstractDataset.get1DIndexFromShape(shape, pos); 400 } 401 402 /** 403 * This function takes a dataset and checks its shape against another dataset. If they are both of the same size, 404 * then this returns with no error, if there is a problem, then an error is thrown. 405 * 406 * @param g 407 * The first dataset to be compared 408 * @param h 409 * The second dataset to be compared 410 * @throws IllegalArgumentException 411 * This will be thrown if there is a problem with the compatibility 412 */ 413 public static void checkCompatibility(final ILazyDataset g, final ILazyDataset h) throws IllegalArgumentException { 414 if (!areShapesCompatible(g.getShape(), h.getShape())) { 415 throw new IllegalArgumentException("Shapes do not match"); 416 } 417 } 418 419 /** 420 * Check that axis is in range [-rank,rank) 421 * 422 * @param rank number of dimensions 423 * @param axis dimension to check 424 * @return sanitized axis in range [0, rank) 425 * @since 2.1 426 */ 427 public static int checkAxis(int rank, int axis) { 428 if (axis < 0) { 429 axis += rank; 430 } 431 432 if (axis < 0 || axis >= rank) { 433 throw new IllegalArgumentException("Axis " + axis + " given is out of range [0, " + rank + ")"); 434 } 435 return axis; 436 } 437 438 private static int[] convert(Collection<Integer> list) { 439 int[] array = new int[list.size()]; 440 int i = 0; 441 for (Integer l : list) { 442 array[i++] = l; 443 } 444 return array; 445 } 446 447 /** 448 * Check that all axes are in range [-rank,rank) 449 * @param rank number of dimensions 450 * @param axes to skip 451 * @return sanitized axes in range [0, rank) and sorted in increasing order 452 * @since 2.2 453 */ 454 public static int[] checkAxes(int rank, int... axes) { 455 return convert(sanitizeAxes(rank, axes)); 456 } 457 458 /** 459 * Check that all axes are in range [-rank,rank) 460 * @param rank number of dimensions 461 * @param axes to skip 462 * @return sanitized axes in range [0, rank) and sorted in increasing order 463 * @since 2.2 464 */ 465 private static SortedSet<Integer> sanitizeAxes(int rank, int... axes) { 466 SortedSet<Integer> nAxes = new TreeSet<>(); 467 for (int i = 0; i < axes.length; i++) { 468 nAxes.add(checkAxis(rank, axes[i])); 469 } 470 471 return nAxes; 472 } 473 474 /** 475 * @param rank number of dimensions 476 * @param axes to skip 477 * @return remaining axes not given by input 478 * @since 2.2 479 */ 480 public static int[] getRemainingAxes(int rank, int... axes) { 481 SortedSet<Integer> nAxes = sanitizeAxes(rank, axes); 482 483 int[] remains = new int[rank - axes.length]; 484 int j = 0; 485 for (int i = 0; i < rank; i++) { 486 if (!nAxes.contains(i)) { 487 remains[j++] = i; 488 } 489 } 490 return remains; 491 } 492 493 /** 494 * Remove axes from shape 495 * @param shape to use 496 * @param axes to remove 497 * @return reduced shape 498 * @since 2.2 499 */ 500 public static int[] reduceShape(int[] shape, int... axes) { 501 int[] remain = getRemainingAxes(shape.length, axes); 502 for (int i = 0; i < remain.length; i++) { 503 int a = remain[i]; 504 remain[i] = shape[a]; 505 } 506 return remain; 507 } 508 509 /** 510 * Set reduced axes to 1 511 * @param shape input 512 * @param axes to set to 1 513 * @return shape with same rank 514 * @since 2.2 515 */ 516 public static int[] getReducedShapeKeepRank(int[] shape, int... axes) { 517 int[] keep = shape.clone(); 518 axes = checkAxes(shape.length, axes); 519 for (int i : axes) { 520 keep[i] = 1; 521 } 522 return keep; 523 } 524 525 /** 526 * @param a first shape 527 * @param b second shape 528 * @return true if arrays only differs by unit entries 529 * @since 2.2 530 */ 531 public static boolean differsByOnes(int[] a, int[] b) { 532 int aRank = a.length; 533 int bRank = b.length; 534 int ai = 0; 535 int bi = 0; 536 int al = 1; 537 int bl = 1; 538 do { 539 while (ai < aRank && (al = a[ai++]) == 1) { // next non-unit dimension 540 } 541 while (bi < bRank && (bl = b[bi++]) == 1) { 542 } 543 if (al != bl) { 544 return false; 545 } 546 } while (ai < aRank && bi < bRank); 547 548 if (ai == aRank) { 549 while (bi < bRank) { 550 if (b[bi++] != 1) { 551 return false; 552 } 553 } 554 } 555 if (bi == bRank) { 556 while (ai < aRank) { 557 if (a[ai++] != 1) { 558 return false; 559 } 560 } 561 } 562 return true; 563 } 564 565 /** 566 * Calculate the padding difference between two shapes. Padding can be positive (negative) 567 * for added (removed) dimensions. NB positive or negative padding is given after matched 568 * dimensions 569 * @param aShape first shape 570 * @param bShape second shape 571 * @return padding can be null if shapes are equal 572 * @throws IllegalArgumentException if one shape is null but not the other, or if shapes do 573 * not possess common non-unit lengths 574 * @since 2.2 575 */ 576 public static int[] calcShapePadding(int[] aShape, int[] bShape) { 577 if (Arrays.equals(aShape, bShape)) { 578 return null; 579 } 580 581 if (aShape == null || bShape == null) { 582 throw new IllegalArgumentException("If one shape is null then the other must be null too"); 583 } 584 585 if (!differsByOnes(aShape, bShape)) { 586 throw new IllegalArgumentException("Non-unit lengths in shapes must be equal"); 587 } 588 int aRank = aShape.length; 589 int bRank = bShape.length; 590 591 int[] padding; 592 if (aRank == 0 || bRank == 0) { 593 padding = new int[1]; 594 padding[0] = aRank == 0 ? bRank : -aRank; 595 return padding; 596 } 597 598 padding = new int[Math.max(aRank, bRank) + 2]; 599 int ai = 0; 600 int bi = 0; 601 int al = 0; 602 int bl = 0; 603 int pi = 0; 604 int p; 605 boolean aLeft = ai < aRank; 606 boolean bLeft = bi < bRank; 607 while (aLeft && bLeft) { 608 if (aLeft) { 609 al = aShape[ai++]; 610 aLeft = ai < aRank; 611 } 612 if (bLeft) { 613 bl = bShape[bi++]; 614 bLeft = bi < bRank; 615 } 616 if (al != bl) { 617 p = 0; 618 while (al == 1 && aLeft) { 619 al = aShape[ai++]; 620 aLeft = ai < aRank; 621 p--; 622 } 623 while (bl == 1 && bLeft) { 624 bl = bShape[bi++]; 625 bLeft = bi < bRank; 626 p++; 627 } 628 padding[pi++] = p; 629 } 630 if (al == bl) { 631 pi++; 632 } 633 } 634 if (aLeft || bLeft) { 635 p = 0; 636 while (ai < aRank && aShape[ai++] == 1) { 637 p--; 638 } 639 while (bi < bRank && bShape[bi++] == 1) { 640 p++; 641 } 642 padding[pi++] = p; 643 } 644 645 return Arrays.copyOf(padding, pi); 646 } 647 648 static int[] padShape(int[] padding, int nr, int[] oldShape) { 649 if (padding == null) { 650 return oldShape.clone(); 651 } 652 int or = oldShape.length; 653 int[] newShape = new int[nr]; 654 int di = 0; 655 for (int i = 0, si = 0; i < padding.length && si <= or && di < nr; i++) { 656 int c = padding[i]; 657 if (c == 0) { 658 newShape[di++] = oldShape[si++]; 659 } else if (c > 0) { 660 int dim = di + c; 661 while (di < dim) { 662 newShape[di++] = 1; 663 } 664 } else if (c < 0) { 665 si -= c; // remove dimensions by skipping forward in source array (should check that they are unit entries) 666 } 667 } 668 while (di < nr) { 669 newShape[di++] = 1; 670 } 671 return newShape; 672 } 673}