001/*******************************************************************************
002 * Copyright (c) 2016 Diamond Light Source Ltd. and others.
003 * All rights reserved. This program and the accompanying materials
004 * are made available under the terms of the Eclipse Public License v1.0
005 * which accompanies this distribution, and is available at
006 * http://www.eclipse.org/legal/epl-v10.html
007 *
008 * Contributors:
009 *     Diamond Light Source Ltd - initial API and implementation
010 *******************************************************************************/
011package org.eclipse.january.dataset;
012
013import java.lang.reflect.Array;
014import java.util.ArrayList;
015import java.util.Arrays;
016import java.util.Collection;
017import java.util.List;
018import java.util.SortedSet;
019import java.util.TreeSet;
020
021public class ShapeUtils {
022
023        private ShapeUtils() {
024        }
025
026        /**
027         * Calculate total number of items in given shape
028         * @param shape dataset shape
029         * @return size
030         */
031        public static long calcLongSize(final int[] shape) {
032                if (shape == null) { // special case of null-shaped
033                        return 0;
034                }
035
036                final int rank = shape.length;
037                if (rank == 0) { // special case of zero-rank shape 
038                        return 1;
039                }
040        
041                double dsize = 1.0;
042                for (int i = 0; i < rank; i++) {
043                        // make sure the indexes isn't zero or negative
044                        if (shape[i] == 0) {
045                                return 0;
046                        } else if (shape[i] < 0) {
047                                throw new IllegalArgumentException(String.format(
048                                                "The %d-th is %d which is not allowed as it is negative", i, shape[i]));
049                        }
050        
051                        dsize *= shape[i];
052                }
053        
054                // check to see if the size is larger than an integer, i.e. we can't allocate it
055                if (dsize > Long.MAX_VALUE) {
056                        throw new IllegalArgumentException("Size of the dataset is too large to allocate");
057                }
058                return (long) dsize;
059        }
060
061        /**
062         * Calculate total number of items in given shape
063         * @param shape dataset shape
064         * @return size
065         */
066        public static int calcSize(final int[] shape) {
067                long lsize = calcLongSize(shape);
068        
069                // check to see if the size is larger than an integer, i.e. we can't allocate it
070                if (lsize > Integer.MAX_VALUE) {
071                        throw new IllegalArgumentException("Size of the dataset is too large to allocate");
072                }
073                return (int) lsize;
074        }
075
076        /**
077         * Check if size is zero
078         * @param shape dataset shape
079         * @return true if shape is null or any dimension in shape is zero
080         * @since 2.3
081         */
082        public static boolean isZeroSize(final int[] shape) {
083                if (shape == null) { // special case of null-shaped
084                        return true;
085                }
086
087                final int rank = shape.length;
088                if (rank == 0) { // special case of zero-rank shape
089                        return false;
090                }
091        
092                for (int i = 0; i < rank; i++) {
093                        if (shape[i] == 0) {
094                                return true;
095                        }
096                }
097        
098                return false;
099        }
100
101        /**
102         * Check if shapes are broadcast compatible
103         * 
104         * @param ashape first shape
105         * @param bshape second shape
106         * @return true if they are compatible
107         */
108        public static boolean areShapesBroadcastCompatible(final int[] ashape, final int[] bshape) {
109                if (ashape == null || bshape == null) {
110                        return ashape == bshape;
111                }
112
113                if (ashape.length < bshape.length) {
114                        return areShapesBroadcastCompatible(bshape, ashape);
115                }
116        
117                for (int a = ashape.length - bshape.length, b = 0; a < ashape.length && b < bshape.length; a++, b++) {
118                        if (ashape[a] != bshape[b] && ashape[a] != 1 && bshape[b] != 1) {
119                                return false;
120                        }
121                }
122        
123                return true;
124        }
125
126        /**
127         * Check if shapes are compatible, ignoring extra axes of length 1
128         * 
129         * @param ashape first shape
130         * @param bshape second shape
131         * @return true if they are compatible
132         */
133        public static boolean areShapesCompatible(final int[] ashape, final int[] bshape) {
134                if (ashape == null || bshape == null) {
135                        return ashape == bshape;
136                }
137
138                List<Integer> alist = new ArrayList<Integer>();
139        
140                for (int a : ashape) {
141                        if (a > 1) alist.add(a);
142                }
143        
144                final int imax = alist.size();
145                int i = 0;
146                for (int b : bshape) {
147                        if (b == 1)
148                                continue;
149                        if (i >= imax || b != alist.get(i++))
150                                return false;
151                }
152        
153                return i == imax;
154        }
155
156        /**
157         * Check if shapes are compatible but skip axis
158         * 
159         * @param ashape first shape
160         * @param bshape second shape
161         * @param axis to skip
162         * @return true if they are compatible
163         */
164        public static boolean areShapesCompatible(final int[] ashape, final int[] bshape, final int axis) {
165                if (ashape == null || bshape == null) {
166                        return ashape == bshape;
167                }
168
169                if (ashape.length != bshape.length) {
170                        return false;
171                }
172        
173                final int rank = ashape.length;
174                for (int i = 0; i < rank; i++) {
175                        if (i != axis && ashape[i] != bshape[i]) {
176                                return false;
177                        }
178                }
179                return true;
180        }
181
182        /**
183         * Remove dimensions of 1 in given shape - from both ends only, if true
184         * 
185         * @param shape dataset shape
186         * @param onlyFromEnds if true, trim ends
187         * @return newly squeezed shape (or original if unsqueezed)
188         */
189        public static int[] squeezeShape(final int[] shape, boolean onlyFromEnds) {
190                int unitDims = 0;
191                int rank = shape.length;
192                int start = 0;
193        
194                if (onlyFromEnds) {
195                        int i = rank - 1;
196                        for (; i >= 0; i--) {
197                                if (shape[i] == 1) {
198                                        unitDims++;
199                                } else {
200                                        break;
201                                }
202                        }
203                        for (int j = 0; j <= i; j++) {
204                                if (shape[j] == 1) {
205                                        unitDims++;
206                                } else {
207                                        start = j;
208                                        break;
209                                }
210                        }
211                } else {
212                        for (int i = 0; i < rank; i++) {
213                                if (shape[i] == 1) {
214                                        unitDims++;
215                                }
216                        }
217                }
218        
219                if (unitDims == 0) {
220                        return shape;
221                }
222        
223                int[] newDims = new int[rank - unitDims];
224                if (unitDims == rank)
225                        return newDims; // zero-rank dataset
226        
227                if (onlyFromEnds) {
228                        rank = newDims.length;
229                        for (int i = 0; i < rank; i++) {
230                                newDims[i] = shape[i+start];
231                        }
232                } else {
233                        int j = 0;
234                        for (int i = 0; i < rank; i++) {
235                                if (shape[i] > 1) {
236                                        newDims[j++] = shape[i];
237                                        if (j >= newDims.length)
238                                                break;
239                                }
240                        }
241                }
242        
243                return newDims;
244        }
245
246        /**
247         * Remove dimension of 1 in given shape
248         * 
249         * @param shape dataset shape
250         * @param axis to remove
251         * @return newly squeezed shape
252         */
253        public static int[] squeezeShape(final int[] shape, int axis) {
254                if (shape == null) {
255                        return null;
256                }
257
258                final int rank = shape.length;
259                if (rank == 0) {
260                        return new int[0];
261                }
262                if (axis < 0) {
263                        axis += rank;
264                }
265                if (axis < 0 || axis >= rank) {
266                        throw new IllegalArgumentException("Axis argument is outside allowed range");
267                }
268                int[] nshape = new int[rank-1];
269                for (int i = 0; i < axis; i++) {
270                        nshape[i] = shape[i];
271                }
272                for (int i = axis+1; i < rank; i++) {
273                        nshape[i-1] = shape[i];
274                }
275                return nshape;
276        }
277
278        /**
279         * Get shape from object (array or list supported)
280         * @param obj object
281         * @return shape can be null if obj is null
282         */
283        public static int[] getShapeFromObject(final Object obj) {
284                if (obj == null) {
285                        return null;
286                }
287
288                ArrayList<Integer> lshape = new ArrayList<Integer>();
289                getShapeFromObj(lshape, obj, 0);
290
291                final int rank = lshape.size();
292                final int[] shape = new int[rank];
293                for (int i = 0; i < rank; i++) {
294                        shape[i] = lshape.get(i);
295                }
296        
297                return shape;
298        }
299
300        /**
301         * Get shape from object
302         * @param ldims
303         * @param obj
304         * @param depth
305         * @return true if there is a possibility of differing lengths
306         */
307        private static boolean getShapeFromObj(final ArrayList<Integer> ldims, Object obj, int depth) {
308                if (obj == null)
309                        return true;
310        
311                if (obj instanceof List<?>) {
312                        List<?> jl = (List<?>) obj;
313                        int l = jl.size();
314                        updateShape(ldims, depth, l);
315                        for (int i = 0; i < l; i++) {
316                                Object lo = jl.get(i);
317                                if (!getShapeFromObj(ldims, lo, depth + 1)) {
318                                        break;
319                                }
320                        }
321                        return true;
322                }
323                Class<? extends Object> ca = obj.getClass().getComponentType();
324                if (ca != null) {
325                        final int l = Array.getLength(obj);
326                        updateShape(ldims, depth, l);
327                        if (InterfaceUtils.isElementSupported(ca)) {
328                                return true;
329                        }
330                        for (int i = 0; i < l; i++) {
331                                Object lo = Array.get(obj, i);
332                                if (!getShapeFromObj(ldims, lo, depth + 1)) {
333                                        break;
334                                }
335                        }
336                        return true;
337                } else if (obj instanceof IDataset) {
338                        int[] s = ((IDataset) obj).getShape();
339                        for (int i = 0; i < s.length; i++) {
340                                updateShape(ldims, depth++, s[i]);
341                        }
342                        return true;
343                } else {
344                        return false; // not an array of any type
345                }
346        }
347
348        private static void updateShape(final ArrayList<Integer> ldims, final int depth, final int l) {
349                if (depth >= ldims.size()) {
350                        ldims.add(l);
351                } else if (l > ldims.get(depth)) {
352                        ldims.set(depth, l);
353                }
354        }
355
356        /**
357         * Get n-D position from given index
358         * @param n absolute index
359         * @param shape dataset shape
360         * @return n-D position
361         */
362        public static int[] getNDPositionFromShape(int n, int[] shape) {
363                if (shape == null) {
364                        return null;
365                }
366
367                int rank = shape.length;
368                if (rank == 0) {
369                        return new int[0];
370                }
371
372                if (rank == 1) {
373                        return new int[] { n };
374                }
375
376                int[] output = new int[rank];
377                for (rank--; rank > 0; rank--) {
378                        output[rank] = n % shape[rank];
379                        n /= shape[rank];
380                }
381                output[0] = n;
382        
383                return output;
384        }
385
386        /**
387         * Get flattened view index of given position
388         * @param shape dataset shape
389         * @param pos
390         *            the integer array specifying the n-D position
391         * @return the index on the flattened dataset
392         */
393        public static int getFlat1DIndex(final int[] shape, final int[] pos) {
394                final int imax = pos.length;
395                if (imax == 0) {
396                        return 0;
397                }
398        
399                return AbstractDataset.get1DIndexFromShape(shape, pos);
400        }
401
402        /**
403         * This function takes a dataset and checks its shape against another dataset. If they are both of the same size,
404         * then this returns with no error, if there is a problem, then an error is thrown.
405         * 
406         * @param g
407         *            The first dataset to be compared
408         * @param h
409         *            The second dataset to be compared
410         * @throws IllegalArgumentException
411         *             This will be thrown if there is a problem with the compatibility
412         */
413        public static void checkCompatibility(final ILazyDataset g, final ILazyDataset h) throws IllegalArgumentException {
414                if (!areShapesCompatible(g.getShape(), h.getShape())) {
415                        throw new IllegalArgumentException("Shapes do not match");
416                }
417        }
418
419        /**
420         * Check that axis is in range [-rank,rank)
421         * 
422         * @param rank number of dimensions
423         * @param axis dimension to check
424         * @return sanitized axis in range [0, rank)
425         * @since 2.1
426         */
427        public static int checkAxis(int rank, int axis) {
428                if (axis < 0) {
429                        axis += rank;
430                }
431        
432                if (axis < 0 || axis >= rank) {
433                        throw new IllegalArgumentException("Axis " + axis + " given is out of range [0, " + rank + ")");
434                }
435                return axis;
436        }
437
438        private static int[] convert(Collection<Integer> list) {
439                int[] array = new int[list.size()];
440                int i = 0;
441                for (Integer l : list) {
442                        array[i++] = l;
443                }
444                return array;
445        }
446
447        /**
448         * Check that all axes are in range [-rank,rank)
449         * @param rank number of dimensions
450         * @param axes to skip
451         * @return sanitized axes in range [0, rank) and sorted in increasing order
452         * @since 2.2
453         */
454        public static int[] checkAxes(int rank, int... axes) {
455                return convert(sanitizeAxes(rank, axes));
456        }
457
458        /**
459         * Check that all axes are in range [-rank,rank)
460         * @param rank number of dimensions
461         * @param axes to skip
462         * @return sanitized axes in range [0, rank) and sorted in increasing order
463         * @since 2.2
464         */
465        private static SortedSet<Integer> sanitizeAxes(int rank, int... axes) {
466                SortedSet<Integer> nAxes = new TreeSet<>(); 
467                for (int i = 0; i < axes.length; i++) {
468                        nAxes.add(checkAxis(rank, axes[i]));
469                }
470
471                return nAxes;
472        }
473
474        /**
475         * @param rank number of dimensions
476         * @param axes to skip
477         * @return remaining axes not given by input
478         * @since 2.2
479         */
480        public static int[] getRemainingAxes(int rank, int... axes) {
481                SortedSet<Integer> nAxes = sanitizeAxes(rank, axes);
482
483                int[] remains = new int[rank - axes.length];
484                int j = 0;
485                for (int i = 0; i < rank; i++) {
486                        if (!nAxes.contains(i)) {
487                                remains[j++] = i;
488                        }
489                }
490                return remains;
491        }
492
493        /**
494         * Remove axes from shape
495         * @param shape to use
496         * @param axes to remove
497         * @return reduced shape
498         * @since 2.2
499         */
500        public static int[] reduceShape(int[] shape, int... axes) {
501                int[] remain = getRemainingAxes(shape.length, axes);
502                for (int i = 0; i < remain.length; i++) {
503                        int a = remain[i];
504                        remain[i] = shape[a];
505                }
506                return remain;
507        }
508
509        /**
510         * Set reduced axes to 1
511         * @param shape input
512         * @param axes to set to 1
513         * @return shape with same rank
514         * @since 2.2
515         */
516        public static int[] getReducedShapeKeepRank(int[] shape, int... axes) {
517                int[] keep = shape.clone();
518                axes = checkAxes(shape.length, axes);
519                for (int i : axes) {
520                        keep[i] = 1;
521                }
522                return keep;
523        }
524
525        /**
526         * @param a first shape
527         * @param b second shape
528         * @return true if arrays only differs by unit entries
529         * @since 2.2
530         */
531        public static boolean differsByOnes(int[] a, int[] b) {
532                int aRank = a.length;
533                int bRank = b.length;
534                int ai = 0;
535                int bi = 0;
536                int al = 1;
537                int bl = 1;
538                do {
539                        while (ai < aRank && (al = a[ai++]) == 1) { // next non-unit dimension
540                        }
541                        while (bi < bRank && (bl = b[bi++]) == 1) {
542                        }
543                        if (al != bl) {
544                                return false;
545                        }
546                } while (ai < aRank && bi < bRank);
547
548                if (ai == aRank) {
549                        while (bi < bRank) {
550                                if (b[bi++] != 1) {
551                                        return false;
552                                }
553                        }
554                }
555                if (bi == bRank) {
556                        while (ai < aRank) {
557                                if (a[ai++] != 1) {
558                                        return false;
559                                }
560                        }
561                }
562                return true;
563        }
564
565        /**
566         * Calculate the padding difference between two shapes. Padding can be positive (negative)
567         * for added (removed) dimensions. NB positive or negative padding is given after matched
568         * dimensions
569         * @param aShape first shape
570         * @param bShape second shape
571         * @return padding can be null if shapes are equal
572         * @throws IllegalArgumentException if one shape is null but not the other, or if shapes do
573         * not possess common non-unit lengths
574         * @since 2.2
575         */
576        public static int[] calcShapePadding(int[] aShape, int[] bShape) {
577                if (Arrays.equals(aShape, bShape)) {
578                        return null;
579                }
580
581                if (aShape == null || bShape == null) {
582                        throw new IllegalArgumentException("If one shape is null then the other must be null too");
583                }
584
585                if (!differsByOnes(aShape, bShape)) {
586                        throw new IllegalArgumentException("Non-unit lengths in shapes must be equal");
587                }
588                int aRank = aShape.length;
589                int bRank = bShape.length;
590
591                int[] padding;
592                if (aRank == 0 || bRank == 0) {
593                        padding = new int[1];
594                        padding[0] = aRank == 0 ? bRank : -aRank;
595                        return padding;
596                }
597
598                padding = new int[Math.max(aRank, bRank) + 2];
599                int ai = 0;
600                int bi = 0;
601                int al = 0;
602                int bl = 0;
603                int pi = 0;
604                int p;
605                boolean aLeft = ai < aRank;
606                boolean bLeft = bi < bRank;
607                while (aLeft && bLeft) {
608                        if (aLeft) {
609                                al = aShape[ai++];
610                                aLeft = ai < aRank;
611                        }
612                        if (bLeft) {
613                                bl = bShape[bi++];
614                                bLeft = bi < bRank;
615                        }
616                        if (al != bl) {
617                                p = 0;
618                                while (al == 1 && aLeft) {
619                                        al = aShape[ai++];
620                                        aLeft = ai < aRank;
621                                        p--;
622                                }
623                                while (bl == 1 && bLeft) {
624                                        bl = bShape[bi++];
625                                        bLeft = bi < bRank;
626                                        p++;
627                                }
628                                padding[pi++] = p;
629                        }
630                        if (al == bl) {
631                                pi++;
632                        }
633                }
634                if (aLeft || bLeft) {
635                        p = 0;
636                        while (ai < aRank && aShape[ai++] == 1) {
637                                p--;
638                        }
639                        while (bi < bRank && bShape[bi++] == 1) {
640                                p++;
641                        }
642                        padding[pi++] = p;
643                }
644
645                return Arrays.copyOf(padding, pi);
646        }
647
648        static int[] padShape(int[] padding, int nr, int[] oldShape) {
649                if (padding == null) {
650                        return oldShape.clone();
651                }
652                int or = oldShape.length;
653                int[] newShape = new int[nr];
654                int di = 0;
655                for (int i = 0, si = 0; i < padding.length && si <= or && di < nr; i++) {
656                        int c = padding[i];
657                        if (c == 0) {
658                                newShape[di++] = oldShape[si++];
659                        } else if (c > 0) {
660                                int dim = di + c;
661                                while (di < dim) {
662                                        newShape[di++] = 1;
663                                }
664                        } else if (c < 0) {
665                                si -= c; // remove dimensions by skipping forward in source array (should check that they are unit entries)
666                        }
667                }
668                while (di < nr) {
669                        newShape[di++] = 1;
670                }
671                return newShape;
672        }
673}