-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdp_node_io_merge.m
executable file
·199 lines (143 loc) · 5.89 KB
/
dp_node_io_merge.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
classdef dp_node_io_merge < dp_node
    % dp_node_io_merge  Merge the outputs of several previous nodes.
    %
    % Runs each node in previous_nodes, keeps only the items whose id is
    % present in the output of every node, and flattens each merged item
    % into a single struct whose fields are prefixed with the name of the
    % node they came from (field 'f' of node 'a' becomes 'a_f'). The
    % 'id' and 'bp' fields are copied from the first node's item, and
    % 'op' from the first node item that has one.
    %
    % Will be called by a dpm iter block.

    properties
        % cell array of previous nodes to merge; their names must be unique
        previous_nodes = {};
    end

    methods

        function obj = dp_node_io_merge(nodes)
            % nodes - cell array of previous nodes whose outputs are merged
            %
            % The no-argument syntax (which MATLAB uses e.g. when growing
            % object arrays) is supported and leaves previous_nodes = {}.
            if (nargin > 0)
                obj.previous_nodes = nodes;
            end
        end

        function previous_outputs = get_iterable(obj)
            % Return the list of items this node iterates over.
            %
            % Outside deep mode, every previous node is run, their outputs
            % are intersected on id and renamed (see intersect_outputs and
            % rename_outputs). In deep mode, the iterable of the leftmost
            % (first) previous node is returned as-is; the per-item merge
            % then happens in run_inner.

            % in deep mode, return leftmost arm always
            % (xxx: subject to change)
            if (obj.opt.deep_mode)
                previous_outputs = obj.previous_nodes{1}.get_iterable();
                return;
            end

            % grab previous output from each previous node
            list_of_outputs = cell(size(obj.previous_nodes));
            for c = 1:numel(obj.previous_nodes)
                list_of_outputs{c} = ...
                    obj.previous_nodes{c}.run(obj.opt.iter_mode, obj.opt);
            end

            % keep only those outputs where the ids intersect
            previous_outputs = dp_node_io_merge.intersect_outputs(list_of_outputs);

            % rename (legacy)
            previous_outputs = dp_node_io_merge.rename_outputs(previous_outputs, ...
                obj.previous_nodes);

            % report on outcome
            obj.log(0, '%t--> Merging outputs resulted in %i items', ...
                numel(previous_outputs));
        end

        function output = run_inner(obj, po)
            % Process a single item.
            %
            % Outside deep mode this defers to dp_node's run_inner. In deep
            % mode, po is passed to run_inner of every previous node and
            % their outputs are merged the same way as in get_iterable.
            if (~obj.opt.deep_mode)
                output = run_inner@dp_node(obj, po);
                return;
            end

            % grab the output of each previous node for this item; wrap each
            % in a cell because intersect_outputs expects lists of items
            pos = cell(size(obj.previous_nodes));
            for c = 1:numel(obj.previous_nodes)
                pos{c} = {obj.previous_nodes{c}.run_inner(po)};
            end

            % keep only those outputs where the ids intersect
            pos = dp_node_io_merge.intersect_outputs(pos);

            % rename (legacy)
            pos = dp_node_io_merge.rename_outputs(pos, obj.previous_nodes);

            % if the previous nodes returned different ids, the intersect is
            % empty; fail clearly instead of with an index-out-of-bounds error
            if (isempty(pos))
                error('dp_node_io_merge:no_common_id', ...
                    'outputs of the merged nodes do not share a common id');
            end

            output = pos{1};
        end

        function obj = update(obj, varargin)
            % Update this node, then forward the same arguments to every
            % previous node.
            obj = update@dp_node_base(obj, varargin{:});
            for c = 1:numel(obj.previous_nodes)
                % NOTE(review): the return value is discarded, so this only
                % takes effect if the previous nodes are handle objects -
                % confirm
                obj.previous_nodes{c}.update(varargin{:});
            end
        end

        function [status, f, age] = input_exist(obj, input)
            % Not implemented yet: always returns empty status, f and age.
            status = []; f = []; age = []; % implement later
            %[status, f, age] = obj.io_exist2(input, obj.input_test);
        end

        function [status, f, age] = output_exist(obj, output)
            % Not implemented yet: always returns empty status, f and age.
            status = []; f = []; age = [];
            %[status, f, age] = obj.io_exist2(output, obj.output_test);
        end

    end

    methods (Static)

        function outputs = intersect_outputs(list_of_outputs)
            % Select the intersect of outputs: keep only the items whose id
            % occurs in every list.
            %
            % list_of_outputs - cell array with one cell array of items per
            %                   previous node; each item is a struct with a
            %                   char 'id' field that is unique within its list
            %
            % outputs - 1 x N cell array, ordered by id as returned by
            %           unique, of structs with fields
            %             id     - the shared id
            %             output - 1 x numel(list_of_outputs) cell array;
            %                      output{j} is the matching item of list j
            %
            % Raises dp_node_io_merge:non_unique_ids if a list contains the
            % same id twice.

            n_lists = numel(list_of_outputs);

            % pull out a column of ids for each list (columns so that lists
            % of different lengths and orientations can be concatenated)
            ids = cell(size(list_of_outputs));
            for c = 1:n_lists
                ids{c} = cellfun(@(x) x.id, list_of_outputs{c}, 'uniformoutput', 0);
                ids{c} = ids{c}(:);
                if (numel(ids{c}) ~= numel(unique(ids{c})))
                    % static method: there is no obj to log through, so
                    % report everything in the error itself
                    error('dp_node_io_merge:non_unique_ids', ...
                        ['non unique ids detected in dp_node_merge (list %i); ' ...
                        'assuming unique ids for the merge to work'], c);
                end
            end

            % ind(i,j) is the position of unique_ids{i} in list j, or 0 when
            % that id is missing from list j
            unique_ids = unique(vertcat(ids{:}));
            ind = zeros(numel(unique_ids), n_lists);
            for j = 1:n_lists
                [~, ind(:, j)] = ismember(unique_ids, ids{j});
            end

            % keep only those with occurrences in each list
            keep = all(ind > 0, 2);
            unique_ids = unique_ids(keep);
            ind = ind(keep, :);

            % assemble the intersect
            outputs = cell(1, numel(unique_ids));
            for i = 1:numel(unique_ids)
                item = struct('id', unique_ids{i});
                item.output = cell(1, n_lists);
                for j = 1:n_lists
                    item.output{j} = list_of_outputs{j}{ind(i, j)};
                end
                outputs{i} = item;
            end
        end

        function outputs = rename_outputs(inputs, nodes)
            % Flatten merged items into single structs with node-prefixed
            % field names.
            %
            % inputs - cell array of structs with fields 'id' and 'output'
            %          (as produced by intersect_outputs)
            % nodes  - cell array of the merged nodes; their 'name' values
            %          are used as field prefixes and must be unique
            %
            % outputs - cell array (same size as inputs) of structs with a
            %           field <name>_<field> for every field of every node's
            %           item, plus 'id' and 'bp' from the first node's item
            %           and 'op' from the first node item that has one.
            %
            % Raises an error if the node names are not unique.

            % grab node names
            names = cellfun(@(n) n.name, nodes, 'uniformoutput', 0);

            % we need unique names, otherwise prefixed fields would collide
            if (numel(names) ~= numel(unique(names)))
                error('merging previous nodes requires unique node names');
            end

            % rename fields
            outputs = cell(size(inputs));
            for i = 1:numel(inputs)
                % assigned first so that 'id' is the first field
                outputs{i}.id = inputs{i}.id;
                for j = 1:numel(inputs{i}.output)
                    item = inputs{i}.output{j};
                    f = fieldnames(item);
                    for k = 1:numel(f)
                        outputs{i}.([names{j} '_' f{k}]) = item.(f{k});
                    end
                end

                % set required fields to that of first node to merge
                outputs{i}.bp = inputs{i}.output{1}.bp;
                outputs{i}.id = inputs{i}.output{1}.id;

                % look through outputs for an op, take the first you find
                for k = 1:numel(inputs{i}.output)
                    if (isfield(inputs{i}.output{k}, 'op'))
                        outputs{i}.op = inputs{i}.output{k}.op;
                        break;
                    end
                end
            end
        end

    end
end