From 3811a6b3b51d19aa792184cf2e1b81e7033e1b07 Mon Sep 17 00:00:00 2001
From: Ferdinand Fischer <ferdinand.fischer@fau.de>
Date: Tue, 25 Feb 2020 15:33:45 +0100
Subject: [PATCH] Fixed quantity.Discrete/on to keep the right order of
 arguments also for 3-dim domains.

---
 +quantity/Discrete.m                | 141 +++++++++++++++-------------
 +quantity/Domain.m                  |  16 +++-
 +unittests/+quantity/testDiscrete.m |  36 ++++++-
 3 files changed, 126 insertions(+), 67 deletions(-)

diff --git a/+quantity/Discrete.m b/+quantity/Discrete.m
index e25c98c..50c86a5 100644
--- a/+quantity/Discrete.m
+++ b/+quantity/Discrete.m
@@ -231,32 +231,40 @@ classdef  (InferiorClasses = {?quantity.Symbolic}) Discrete < handle & matlab.mi
 	end
 	
 	methods (Access = public)
-		function obj_hat = compose(obj, g, varargin)
-			% COMPOSE compose two functions
-			%	OBJ_hat = compose(obj, G, varargin) composes the function
-			%	defined by OBJ with the function given by G. In particular,
-			%		f_hat(z,t) = f( g(z,t) )
-			%	if f(t) = obj, g is G and f_hat is OBJ_hat.
+		function [d, I, d_size] = compositionDomain(obj, g, varargin)
 			
 			assert(isscalar(g));
-			assert(nargin(obj) == 1 );
-			
-			newDomain = g.domain();
-			composeOnDomain = g.on();
+
+			d = g.on();
 			
 			% the evaluation of obj.on( compositionDomain ) is done by:
-			domainSize = size(composeOnDomain);
+			d_size = size(d);
 			
 			% 1) vectorization of the n-d-grid: compositionDomain	
-			composeOnDomain = composeOnDomain(:);
+			d = d(:);
 
 			% 2) then it is sorted in ascending order
-			[composeOnDomain, I] = sort(composeOnDomain);			
+			[d, I] = sort(d);			
 			
 			% verify the domain to be monotonical increasing
-			deltaCOD = diff(composeOnDomain);
+			deltaCOD = diff(d);
 			assert(misc.alln(deltaCOD >= 0), 'The domain for the composition f(g(.)) must be monotonically increasing');
 
+
+		end
+		
+		function obj_hat = compose(obj, g, varargin)
+			% COMPOSE compose two functions
+			%	OBJ_hat = compose(obj, G, varargin) composes the function
+			%	defined by OBJ with the function given by G. In particular,
+			%		f_hat(z,t) = f( g(z,t) )
+			%	if f(t) = obj, g is G and f_hat is OBJ_hat.
+			
+			assert(nargin(obj) == 1 );
+			
+			[composeOnDomain, I, domainSize] = ...
+				obj.compositionDomain(g, varargin{:});
+
 			% check if the composition domain is in the range of definition
 			% of obj.
 			if( obj.domain.lower > composeOnDomain(1) || ...
@@ -264,7 +272,7 @@ classdef  (InferiorClasses = {?quantity.Symbolic}) Discrete < handle & matlab.mi
 			
 				warning('quantity:Discrete:compose', ....
 					'The composition domain is not a subset of obj.domain! The missing values will be extrapolated.');
-			end
+			end			
 			
 			% 3) evaluation on the new grid:
 			newValues = obj.on( composeOnDomain );
@@ -280,7 +288,7 @@ classdef  (InferiorClasses = {?quantity.Symbolic}) Discrete < handle & matlab.mi
 			obj_hat = quantity.Discrete( newValues, ...
 				'name', [obj.name '°' g.name], ...
 				'size', size(obj), ...
-				'domain', newDomain );
+				'domain', g.domain());
 			
 		end
 		
@@ -346,15 +354,17 @@ classdef  (InferiorClasses = {?quantity.Symbolic}) Discrete < handle & matlab.mi
 					gridPermuteIdx = 1:length(myDomain);
 				else
 					assert(numel(myDomain) == numel(obj(1).domain), ['Wrong grid for the evaluation of the object']);
-					[myDomain, gridPermuteIdx] = obj(1).domain.permute(myDomain);
+					% compute the permutation index, in order to bring the
+					% new domain in the same order as the original one.
+					gridPermuteIdx = obj(1).domain.getPermutationIdx(myDomain);
 				end			
 				% get the valueDiscrete data for this object. Apply the
 				% permuted myDomain. Then the obj2value will be evaluated
 				% in the order of the original domain. The permuatation to
 				% the new order will be done in the next step.
-				value = obj.obj2value(myDomain(gridPermuteIdx));
-				
-				value = permute(reshape(value, [cellfun(@(v) numel(v), {myDomain(gridPermuteIdx).grid}), size(obj)]), ...
+				originalOrderedDomain(gridPermuteIdx) = myDomain;
+				value = obj.obj2value(originalOrderedDomain);
+				value = permute(reshape(value, [cellfun(@(v) numel(v), {originalOrderedDomain.grid}), size(obj)]), ...
 					[gridPermuteIdx, numel(gridPermuteIdx)+(1:ndims(obj))]);
 			end
 		end
@@ -1978,36 +1988,6 @@ classdef  (InferiorClasses = {?quantity.Symbolic}) Discrete < handle & matlab.mi
 				'name', [A(1).name, '+', B(1).name]);
 		end
 		
-		function [valDiscrete] = expandValueDiscrete(obj, newDomain)
-			% EXPANDVALUEDISCRETE
-			%	[valDiscrete, gridNameSorted] = ...
-			%       expandValueDiscrete(obj, gridIndex, valDiscrete)
-			%	Expanses the value of obj, so that
-			
-			
-			gridNameJoined  = {newDomain.name};
-			gridJoinedLength = newDomain.gridLength;
-			
-			% get the index of obj.grid in the joined grid
-			[~, logicalIdx] = newDomain.gridIndex({obj(1).domain.name});
-			
-			valDiscrete = obj.on( newDomain(logicalIdx) );
-			oldDim = ndims(valDiscrete);
-			valDiscrete = permute(valDiscrete, [(1:sum(~logicalIdx)) + oldDim, 1:oldDim] );
-			valDiscrete = repmat(valDiscrete, [gridJoinedLength(~logicalIdx), ones(1, ndims(valDiscrete))]);
-			%
-			valDiscrete = reshape(valDiscrete, ...
-				[gridJoinedLength(~logicalIdx), gridJoinedLength(logicalIdx), size(obj)]);
-			
-			% permute valDiscrete such that grids are in the order specified
-			% by gridNameJoined.
-			gridIndex = 1:numel(logicalIdx);
-			gridOrder = [gridIndex(~logicalIdx), gridIndex(logicalIdx)];
-			valDiscrete = permute(valDiscrete, [gridOrder, numel(logicalIdx)+(1:ndims(obj))]);
-			
-		end
-		
-		
 		function C = minus(A, B)
 			% minus uses plus()
 			C = A + (-B);
@@ -2208,25 +2188,25 @@ classdef  (InferiorClasses = {?quantity.Symbolic}) Discrete < handle & matlab.mi
 	%%
 	methods (Static)
 		
-		function P = zeros(valueSize, grid, varargin)
+		function P = zeros(valueSize, domain, varargin)
 			%ZEROS initializes an zero quantity.Discrete object
-			%   P = zeros(VALUESIZE, GRID) returns a quantity.Discrete
-			%   object that has only zero entries on a grid which is
-			%   defined by the cell GRID and the value size defined by the
-			%   size-vector VALUESIZE
-			%       P = quantity.Discrete([n,m], {linspace(0,1)',
-			%                                       linspace(0,10)});
-			%       creates an (n times m) zero quantity.Discrete on the
-			%       grid (0,1) x (0,10)
-			%
 			%	P = zeros(VALUESIZE, DOMAIN) creates a matrix of size
 			%	VALUESIZE on the DOMAIN with zero entries.
-			if ~iscell(grid)
-				grid = {grid};
+			
+			myParser = misc.Parser();
+			myParser.addParameter('gridName', '');
+			myParser.parse(varargin{:});
+			
+			if ~isa(domain, 'quantity.Domain')
+				% if the input parameter DOMAIN is not a quantity.Domain
+				% object. It is assumed that it is a grid.
+				grids = misc.ensureIsCell(domain);
+				gridNames = misc.ensureIsCell( myParser.Results.gridName );
+				domain = quantity.Domain.gridCells2domain(grids, gridNames);
 			end
-			gridSize = cellfun('length', grid);
-			O = zeros([gridSize(:); valueSize(:)]');
-			P = quantity.Discrete(O, 'size', valueSize, 'grid', grid, varargin{:});
+			
+			O = zeros([domain.gridLength, valueSize(:)']);
+			P = quantity.Discrete(O, 'size', valueSize, 'domain', domain, varargin{:});
 		end
 		
 		function q = value2cell(value, gridSize, valueSize)
@@ -2276,6 +2256,37 @@ classdef  (InferiorClasses = {?quantity.Symbolic}) Discrete < handle & matlab.mi
 	end %% (Static)
 	methods(Access = protected)
 		
+		function [valDiscrete] = expandValueDiscrete(obj, newDomain)
+			% EXPANDVALUEDISCRETE expand the discrete value on the
+			% newDomain
+			%	[valDiscrete] = ...
+			%       expandValueDiscrete(obj, newDomain) expands the
+			%       discrete values on a new domain. So that a function
+			%			f(z,t) = f(z) + f(t)
+			%	can be computed.
+		
+			gridJoinedLength = newDomain.gridLength;
+			
+			% get the index of obj.grid in the joined grid
+			[idx, logicalIdx] = newDomain.gridIndex({obj(1).domain.name});
+			% evaluate the 
+			valDiscrete = obj.on( newDomain(logicalIdx) );
+			oldDim = ndims(valDiscrete);
+			valDiscrete = permute(valDiscrete, [(1:sum(~logicalIdx)) + oldDim, 1:oldDim] );
+			valDiscrete = repmat(valDiscrete, [gridJoinedLength(~logicalIdx), ones(1, ndims(valDiscrete))]);
+			%
+			valDiscrete = reshape(valDiscrete, ...
+				[gridJoinedLength(~logicalIdx), gridJoinedLength(logicalIdx), size(obj)]);
+			
+			% permute valDiscrete such that grids are in the order specified
+			% by gridNameJoined.
+			gridIndex = 1:numel(logicalIdx);
+			gridOrder = [gridIndex(~logicalIdx), gridIndex(logicalIdx)];
+			gridIndex(gridOrder) = 1:numel(logicalIdx);
+			
+			valDiscrete = permute(valDiscrete, [gridIndex, numel(logicalIdx)+(1:ndims(obj))]);
+		end
+		
 		function result = diff_inner(obj, k, diffGridName)
 			gridSelector = strcmp(obj(1).gridName, diffGridName);
 			gridSelectionIndex = find(gridSelector);
diff --git a/+quantity/Domain.m b/+quantity/Domain.m
index f17d5a0..e338050 100644
--- a/+quantity/Domain.m
+++ b/+quantity/Domain.m
@@ -213,7 +213,7 @@ classdef Domain < handle & matlab.mixin.CustomDisplay
 			end
 		end % ndgrid()
 		
-		function [newDomain, idx] = permute(obj, order)
+		function [idx, newDomain] = getPermutationIdx(obj, order)
 			
 			if isa(order, 'quantity.Domain')
 				names = {order.name};
@@ -309,6 +309,20 @@ classdef Domain < handle & matlab.mixin.CustomDisplay
 	end % Access = protected
 	
 	methods (Static)
+		
+		function d = gridCells2domain(grids, gridNames)
+			grids = misc.ensureIsCell(grids);
+			gridNames = misc.ensureIsCell(gridNames);
+			
+			assert(length( grids ) == length(gridNames))
+
+			d = quantity.Domain.empty();
+			
+			for k = 1:length(grids)
+				d = [d quantity.Domain('grid', grids{k}, 'name', gridNames{k})];
+			end
+		end
+		
 		function g = defaultGrid(gridSize, name)
 			
 			if nargin == 1
diff --git a/+unittests/+quantity/testDiscrete.m b/+unittests/+quantity/testDiscrete.m
index 65e6614..cb86570 100644
--- a/+unittests/+quantity/testDiscrete.m
+++ b/+unittests/+quantity/testDiscrete.m
@@ -587,7 +587,19 @@ testCase.verifyEqual(permute(createTestData(linspace(0, 1, 21), linspace(0, 1, 2
 		value(:,:,2,3) = 1+zeros(numel(gridVecA), numel(gridVecB));
 	end
 
+%% test on on 3-dim domain
+z = quantity.Domain('grid', 0:3, 'name', 'z');
+t = quantity.Domain('grid', 0:4, 'name', 't');
+x = quantity.Domain('grid', 0:5, 'name', 'x');
 
+Z = quantity.Discrete( z.grid, 'domain', z);
+T = quantity.Discrete( t.grid, 'domain', t);
+X = quantity.Discrete( x.grid, 'domain', x);
+
+ZTX = Z+T+X;
+XTZ = X+T+Z;
+
+testCase.verifyEqual( ZTX.on([t z x]), XTZ.on([t z x]), 'AbsTol', 1e-12);
 
 end
 
@@ -1048,7 +1060,6 @@ eMatReference = zGrid + zetaGrid;
 testCase.verifyEqual(numel(eMat), numel(eMatReference));
 testCase.verifyEqual(eMat(:), eMatReference(:));
 
-
 %% addition with constant values
 testCase.verifyEqual(permute([a b], [1 3 2]), AB.on());
 
@@ -1064,6 +1075,29 @@ cAB2 = [1 2; 3 4] + AB2;
 testCase.verifyEqual(AB2c.on(), tst)
 testCase.verifyEqual(cAB2.on(), tst)
 
+%% test plus on different domains
+z = quantity.Domain('grid', 0:3, 'name', 'z');
+t = quantity.Domain('grid', 0:4, 'name', 't');
+x = quantity.Domain('grid', 0:5, 'name', 'x');
+
+T = quantity.Discrete( t.grid, 'domain', t);
+Z = quantity.Discrete( z.grid, 'domain', z);
+X = quantity.Discrete( x.grid, 'domain', x);
+
+TZX = T+Z+X;
+
+TZ1 = T + Z + 1;
+TZX1 = subs(TZX, 'x', 1);
+
+testCase.verifyEqual( TZX1.on(), TZ1.on(), 'AbsTol', 1e-12 );
+testCase.verifyEqual( on( T + 0.5 + X  ), on( subs( TZX, 'z', 0.5) ), 'AbsTol', 1e-12 );
+
+XTZ = X+T+Z;
+
+TZX + XTZ;
+
+%TODO
+
 end
 
 function testInit(testCase)
-- 
GitLab