Commit b69997f4 authored by Jakob Gabriel's avatar Jakob Gabriel
Browse files

quantity.Discrete bugfixes:

- solveAlgebraic: adjusted to change of changeGrid from call-by-reference to call-by-value
- diff(): now supports specification of various variables for which the derivative should be applied. Ensured consistency with Symbolic. Maybe a more elegant use of gradient() is possible.
- unittests: adjusted for now behavior of diff()

quantity.Symbolic bugfixes:
- diff(): now supports specification of various variables for which the derivative should be applied. Ensured consistency with Discrete.
- unittests: adjusted for now behavior of diff()
parent 3d538332
......@@ -562,7 +562,7 @@ classdef (InferiorClasses = {?quantity.Symbolic, ?quantity.Operator}) Discrete
min(find(gridSelector, 1, 'last')+1, numel(gridSelector))]) = 1;
limitedGrid = obj(1).grid{1}(gridSelector);
objCopy = obj.copy();
objCopy.changeGrid({limitedGrid}, gridName);
objCopy = objCopy.changeGrid({limitedGrid}, gridName);
objInverseTemp = objCopy.invert(gridName);
else
objInverseTemp = obj.invert(gridName);
......@@ -1371,7 +1371,7 @@ classdef (InferiorClasses = {?quantity.Symbolic, ?quantity.Operator}) Discrete
% index as "idx" and its logical index as "log"
if nargin == 1
names = obj(1).gridName();
names = obj(1).gridName;
end
if ~iscell(names)
......@@ -1395,11 +1395,14 @@ classdef (InferiorClasses = {?quantity.Symbolic, ?quantity.Operator}) Discrete
end
function result = diff(obj, k, gridName)
function result = diff(obj, k, diffGridName)
% diff applies the kth-derivative for the variable specified with
% the input gridName to the obj.
if nargin == 1
k = 1;
% the input gridName to the obj. If no gridName is specified, then diff
% applies the derivative w.r.t. to all gridNames.
if nargin == 1 || isempty(k)
k = 1; % by default, only one derivatve per diffGridName is applied
else
assert(isnumeric(k) && (round(k) == k))
end
if obj.isConstant && isempty(obj(1).gridName)
......@@ -1410,42 +1413,57 @@ classdef (InferiorClasses = {?quantity.Symbolic, ?quantity.Operator}) Discrete
return
end
if nargin <= 2
gridName = obj(1).gridName{:};
if nargin < 3 % if no diffGridName is specified, then the derivative
% w.r.t. to all gridNames is applied
diffGridName = obj(1).gridName;
end
% if a higher order derivative is requested, call the function
% recursivly until the first-order derivative is reached
if k > 1
assert(isnumeric(k))
obj = obj.diff(k-1, gridName);
end
gridSelector = strcmp(obj(1).gridName, gridName);
gridSelectionIndex = find(gridSelector);
spacing = gradient(obj(1).grid{gridSelectionIndex}, 1);
assert(numeric.near(spacing, spacing(1)), ...
'diff is currently only implemented for equidistant grid');
permutationVector = 1 : (numel(obj(1).grid)+ndims(obj));
objDiscrete = permute(obj.on(), ...
[permutationVector(gridSelectionIndex), ...
permutationVector(permutationVector ~= gridSelectionIndex)]);
if size(objDiscrete, 2) == 1
derivativeDiscrete = gradient(objDiscrete, spacing(1));
% diff for each element of diffGridName (this is rather inefficient,
% but an easy implementation of the specification)
if iscell(diffGridName) || isempty(diffGridName)
if numel(diffGridName) == 0 || isempty(diffGridName)
result = copy(obj);
else
result = obj.diff(k, diffGridName{1}); % init result
for it = 2 : numel(diffGridName)
result = result.diff(k, diffGridName{it});
end
end
else
[~, derivativeDiscrete] = gradient(objDiscrete, spacing(1));
gridSelector = strcmp(obj(1).gridName, diffGridName);
gridSelectionIndex = find(gridSelector);
spacing = gradient(obj(1).grid{gridSelectionIndex}, 1);
assert(numeric.near(spacing, spacing(1)), ...
'diff is currently only implemented for equidistant grid');
permutationVector = 1 : (numel(obj(1).grid)+ndims(obj));
objDiscrete = permute(obj.on(), ...
[permutationVector(gridSelectionIndex), ...
permutationVector(permutationVector ~= gridSelectionIndex)]);
if size(objDiscrete, 2) == 1
derivativeDiscrete = gradient(objDiscrete, spacing(1));
else
[~, derivativeDiscrete] = gradient(objDiscrete, spacing(1));
end
rePermutationVector = [2:(gridSelectionIndex), ...
1, (gridSelectionIndex+1):ndims(derivativeDiscrete)];
result = quantity.Discrete(...
permute(derivativeDiscrete, rePermutationVector), ...
'size', size(obj), 'grid', obj(1).grid, ...
'gridName', obj(1).gridName, ...
'name', ['(d_{', diffGridName, '}', obj(1).name, ')']);
if k > 1
% % if a higher order derivative is requested, call the function
% % recursivly until the first-order derivative is reached
result = result.diff(k-1, diffGridName);
end
end
rePermutationVector = [2:(gridSelectionIndex), ...
1, (gridSelectionIndex+1):ndims(derivativeDiscrete)];
result = quantity.Discrete(...
permute(derivativeDiscrete, rePermutationVector), ...
'size', size(obj), 'grid', obj(1).grid, ...
'gridName', obj(1).gridName, ...
'name', ['(d_{', gridName, '}', obj(1).name, ')']);
end
function I = int(obj, varargin)
......
......@@ -493,22 +493,55 @@ classdef Symbolic < quantity.Function
'name', ['sqrtm(', x(1).name, ')']);
end
function D = diff(obj, k, gridName)
if nargin == 1
k = 1;
function result = diff(obj, k, diffGridName)
% diff applies the kth-derivative for the variable specified with
% the input gridName to the obj. If no gridName is specified, then diff
% applies the derivative w.r.t. to all gridNames / variables.
if nargin == 1 || isempty(k)
k = 1; % by default, only one derivatve per diffGridName is applied
end
if nargin == 3
error('Not yet implemented')
if nargin <= 2 % if no diffGridName is specified, then the derivative
% w.r.t. to all gridNames is applied
diffGridName = obj(1).gridName;
end
% TODO: specify for which variable it should be differentiated
D = obj.copy();
[D.name] = deal(['(d_{' char(obj(1).variable) '} ' D(1).name, ')']);
[D.valueDiscrete] = deal([]);
for l = 1:numel(obj)
D(l).valueSymbolic = diff(obj(l).valueSymbolic, k);
D(l).valueContinuous = obj.setValueContinuous(D(l).valueSymbolic, obj(1).variable);
result = obj.copy();
if iscell(diffGridName)
for it = 1 : numel(diffGridName)
result = result.diff(k, diffGridName{it});
end
else
diffVariable = obj.gridName2variable(diffGridName);
[result.name] = deal(['(d_{' char(diffVariable) '} ' result(1).name, ')']);
[result.valueDiscrete] = deal([]);
for l = 1:numel(obj)
result(l).valueSymbolic = diff(obj(l).valueSymbolic, diffVariable, k);
result(l).valueContinuous = obj.setValueContinuous(result(l).valueSymbolic, obj(1).variable);
end
end
end
function thisVariable = gridName2variable(obj, thisGridName)
% this method returns the variable thisVariable stored in obj(1).variable
% in the order specified by thisGridName. If thisGridName is a char, then
% only one variable is returned. If thisGridName is a cell-array, then a
% cell array of variables is returned.
if ischar(thisGridName)
assert(any(strcmp(obj(1).gridName, thisGridName)), ...
['The gridName ', thisGridName, ' is not a gridName of this Quantity']);
variableNames = arrayfun(@(v) char(v), obj(1).variable, 'UniformOutput', false);
selectVariable = strcmp(variableNames, thisGridName);
assert(sum(selectVariable) == 1);
thisVariable = obj(1).variable(selectVariable);
elseif iscell(thisGridName)
thisVariable = cell(size(thisGridName));
for it = 1 : numel(thisGridName)
thisVariable{it} = obj.gridName2variable(thisGridName{it});
end
else
error(['The input gridName of gridName2variable() must be a char-array', ...
' or a cell-array']);
end
end
......
......@@ -297,22 +297,25 @@ myQuantity = quantity.Discrete(cat(3, 2*ones(11, 21), zNdgrid, zetaNdgrid), ...
'gridName', {'z', 'zeta'}, 'name', 'constant', 'size', [3, 1]);
myQuantityDz = diff(myQuantity, 1, 'z');
myQuantityDzeta = diff(myQuantity, 1, 'zeta');
myQuantityDzeta2 = diff(myQuantity, 1);
myQuantityDZzeta = diff(myQuantity, 1);
myQuantityDZzeta2 = diff(myQuantity, 1, {'z', 'zeta'});
testCase.verifyEqual(myQuantityDZzeta.on(), myQuantityDZzeta2.on());
% constant
testCase.verifyEqual(myQuantityDz(1).on(), zeros(11, 21));
testCase.verifyEqual(myQuantityDzeta(1).on(), zeros(11, 21));
testCase.verifyEqual(myQuantityDzeta2(1).on(), zeros(11, 21));
testCase.verifyEqual(myQuantityDZzeta(1).on(), zeros(11, 21));
% zNdgrid
testCase.verifyEqual(myQuantityDz(2).on(), ones(11, 21), 'AbsTol', 10*eps);
testCase.verifyEqual(myQuantityDzeta(2).on(), zeros(11, 21), 'AbsTol', 10*eps);
testCase.verifyEqual(myQuantityDzeta2(2).on(), ones(11, 21), 'AbsTol', 10*eps);
testCase.verifyEqual(myQuantityDZzeta(2).on(), zeros(11, 21), 'AbsTol', 10*eps);
% zetaNdgrid
testCase.verifyEqual(myQuantityDz(3).on(), zeros(11, 21), 'AbsTol', 10*eps);
testCase.verifyEqual(myQuantityDzeta(3).on(), ones(11, 21), 'AbsTol', 10*eps);
testCase.verifyEqual(myQuantityDzeta2(3).on(), zeros(11, 21), 'AbsTol', 10*eps);
testCase.verifyEqual(myQuantityDZzeta(3).on(), zeros(11, 21), 'AbsTol', 10*eps);
end
function testOn(testCase)
......
......@@ -4,6 +4,47 @@ function [tests ] = testSymbolic()
tests = functiontests(localfunctions());
end
function testDiffWith2Variables(testCase)
syms z zeta
myGrid = linspace(0, 1, 7);
f = quantity.Symbolic([z*zeta, z; zeta, 1], ...
'variable', {z, zeta}, 'grid', {myGrid, myGrid});
fdz = quantity.Symbolic([zeta, 1; 0, 0], ...
'variable', {z, zeta}, 'grid', {myGrid, myGrid});
fdzeta = quantity.Symbolic([z, 0; 1, 0], ...
'variable', {z, zeta}, 'grid', {myGrid, myGrid});
fRefTotal = 0*f.on();
fRefTotal(:,:,1,1) = 1;
testCase.verifyEqual(on(diff(f)), fRefTotal);
testCase.verifyEqual(on(diff(f, 1, {'z', 'zeta'})), fRefTotal);
testCase.verifyEqual(on(diff(f, 2, {'z', 'zeta'})), 0*fRefTotal);
testCase.verifyEqual(on(diff(f, 1, 'z')), fdz.on());
testCase.verifyEqual(on(diff(f, 1, 'zeta')), fdzeta.on());
end
function testGridName2variable(testCase)
syms z zeta eta e a
myGrid = linspace(0, 1, 7);
obj = quantity.Symbolic([z*zeta, eta*z; e*a, a], ...
'variable', {z zeta eta e a}, ...
'grid', {myGrid, myGrid, myGrid, myGrid, myGrid});
thisGridName = {'eta', 'e', 'z'};
thisVariable = gridName2variable(obj, thisGridName);
variableNames = cellfun(@(v) char(v), thisVariable, 'UniformOutput', false);
testCase.verifyEqual(thisGridName, variableNames);
thisGridName = 'z';
thisVariable = gridName2variable(obj, thisGridName);
testCase.verifyEqual(thisGridName, char(thisVariable));
testCase.verifyError(@() gridName2variable(obj, 't'), '');
testCase.verifyError(@() gridName2variable(obj, 123), '');
end
function testCat(testCase)
syms z zeta
f1 = quantity.Symbolic(1+z*z, 'grid', {linspace(0, 1, 21)}, 'variable', {z}, 'name', 'f1');
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment