-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmmaformatter.py
280 lines (221 loc) · 9.21 KB
/
mmaformatter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
import math
import os
import re
'''
A library for generating strings readable by Mathematica's Get[ ] from Python objects
author: Tyson Jones
tyson.jones@materials.ox.ac.uk
date: 25 Nov 2017
'''
# defaults to get_mma optional arguments
_DEFAULT_PRECISION = 5
_DEFAULT_KEEP_INTS = True
_DEFAULT_KEEP_SYMBOLS = False
_INDENT_CHAR = ' '*4
# formats of MMA objects
_STRING_FORMAT = '"%s"'
_BASE_TEN_FORMAT = '*10^'
_COMPLEX_FORMAT = '%s%s%sI'
_ARRAY_OUTER_FORMAT = '{%s}'
_ARRAY_JOIN_FORMAT = ', '
_DICT_OUTER_FORMAT = '<|\n%s\n|>'
_DICT_JOIN_FORMAT = ',\n'
_DICT_ITEM_FORMAT = '%s -> %s'
def _enquote(thing):
    '''
    Returns str(thing) as a Mathematica string literal (wrapped in double quotes).
    Backslashes and double quotes are escaped first. Inside a Mathematica
    string, a raw backslash would start an escape sequence and a raw double
    quote would end the literal early, corrupting the expression read by Get[ ].
    '''
    # escape backslashes before quotes, so the backslashes added for the
    # quotes are not themselves doubled
    text = str(thing).replace('\\', '\\\\').replace('"', '\\"')
    return _STRING_FORMAT % text
def _get_mma_string(
        string,
        keep_symbols=_DEFAULT_KEEP_SYMBOLS):
    '''
    Formats a python string for Mathematica.
    With keep_symbols the text is emitted verbatim (read as a symbol);
    otherwise it is enquoted as a string literal via _enquote.
    '''
    if not keep_symbols:
        return _enquote(string)
    return string
def _get_mma_bool(
boolean):
return str(boolean)
def _get_mma_real(
        num,
        precision=_DEFAULT_PRECISION,
        keep_ints=_DEFAULT_KEEP_INTS):
    '''
    Formats a real number (int, py2.X long or float) for Mathematica.
    precision: number of digits after the decimal point in scientific notation
    keep_ints: if True, ints/longs and integer-valued floats are written as
               plain integers; otherwise everything uses scientific notation,
               written as <mantissa>*10^<exponent>
    Non-finite floats become Infinity, -Infinity or Indeterminate, because
    python's 'inf' / 'nan' would be read by Mathematica as undefined symbols.
    '''
    if isinstance(num, float):
        if math.isnan(num):
            return 'Indeterminate'
        if math.isinf(num):
            return 'Infinity' if num > 0 else '-Infinity'
    if keep_ints:
        if isinstance(num, int) or type(num).__name__ == 'long':
            return str(num)
        if isinstance(num, float) and num.is_integer():
            return str(int(num))
    # e.g. '1.235e-01' -> '1.235*10^-01'
    return format(num, '.%de' % precision).replace('e', _BASE_TEN_FORMAT)
def _get_mma_complex(
        num,
        precision=_DEFAULT_PRECISION,
        keep_ints=_DEFAULT_KEEP_INTS):
    '''
    Formats a complex number as '<real><sign><imag>I'. Each component is
    formatted with _get_mma_real using the given precision and keep_ints.
    '''
    real = _get_mma_real(num.real, precision=precision, keep_ints=keep_ints)
    imag = _get_mma_real(num.imag, precision=precision, keep_ints=keep_ints)
    # A formatted negative component already carries its '-'. Every other value,
    # including a zero imaginary part, needs an explicit '+'; otherwise
    # 3+0j would render as '30I', i.e. 30*I.
    sign = '' if imag.startswith('-') else '+'
    # MMA treats trailing I as factor, not exponent
    return _COMPLEX_FORMAT % (real, sign, imag)
def _get_mma_array(
        array,
        keep_symbols=_DEFAULT_KEEP_SYMBOLS,
        precision=_DEFAULT_PRECISION,
        keep_ints=_DEFAULT_KEEP_INTS,
        single_line=True):
    '''
    Formats a python list as a Mathematica List '{a, b, ...}'.
    Each element is formatted recursively through get_mma with the same options.
    '''
    elements = [
        get_mma(element,
                keep_symbols=keep_symbols,
                precision=precision,
                keep_ints=keep_ints,
                single_line=single_line)
        for element in array
    ]
    body = _ARRAY_JOIN_FORMAT.join(elements)
    return _ARRAY_OUTER_FORMAT % body
def _get_mma_dict(
        dic,
        key_order=None,  # not passed on
        keep_symbols=_DEFAULT_KEEP_SYMBOLS,
        precision=_DEFAULT_PRECISION,
        keep_ints=_DEFAULT_KEEP_INTS,
        single_line=True):
    '''
    Formats a python dict as a Mathematica Association of key -> value rules.
    Every key and value is formatted through get_mma with the given options,
    always on a single line.
    key_order:   keys in the order they are written; if not given, the dict's
                 own iteration order is used (validation is done by get_mma)
    single_line: if True the result is '<| k -> v, ... |>' on one line;
                 otherwise '<|', one indented rule per line, then '|>'
    '''
    # Turn only the templates' layout newlines into spaces for single-line
    # output. Rewriting the finished string instead would also alter newlines
    # inside formatted string keys/values and silently change the data.
    separator = ' ' if single_line else '\n'
    outer_format = _DICT_OUTER_FORMAT.replace('\n', separator)
    join_format = _DICT_JOIN_FORMAT.replace('\n', separator)
    indent = '' if single_line else _INDENT_CHAR
    items = [
        indent + _DICT_ITEM_FORMAT % (
            get_mma(key,
                    keep_symbols=keep_symbols,
                    precision=precision,
                    keep_ints=keep_ints,
                    single_line=True),
            get_mma(dic[key],
                    keep_symbols=keep_symbols,
                    precision=precision,
                    keep_ints=keep_ints,
                    single_line=True))
        for key in (key_order if key_order else dic)
    ]
    return outer_format % join_format.join(items)
def get_mma(
        obj,
        key_order=None,  # not passed on
        keep_symbols=_DEFAULT_KEEP_SYMBOLS,
        precision=_DEFAULT_PRECISION,
        keep_ints=_DEFAULT_KEEP_INTS,
        single_line=False):
    '''
    Constructs a nested Mathematica expression from a python structure consisting
    of strings, numbers (including complex and py2.X longs), bools, lists,
    tuples, sets and frozensets (converted to lists) and dictionaries.

    key_order:    an explicit order for the keys in the passed dictionary;
                  can only be given when the passed obj is a dictionary, and
                  must list every key of that dictionary exactly once
    keep_symbols: whether python strings should be converted to Mathematica symbols,
                  else enquoted to become strings
    precision:    the number of decimal digits in scientific notation format of numbers
    keep_ints:    whether integers should be kept formatted as such, else converted to
                  scientific notation (at the precision supplied); applies to the real
                  and imaginary components of complex numbers too
    single_line:  whether a top-level dictionary is written on one line, else
                  one indented key -> value rule per line; nested dictionaries
                  and all lists are always written on a single line

    raises ValueError for an invalid key_order, and TypeError when obj (or
    anything nested inside it) has an unsupported type
    '''
    # check parameters are valid
    if key_order:
        if not isinstance(obj, dict):
            raise ValueError('key_order can only be used by an outer-most dictionary')
        if any(key not in obj for key in key_order):
            raise ValueError('key_order contains a key not present in supplied dictionary')
        # every entry is now known to be a key of obj (hence hashable), so a
        # set gives O(1) membership tests
        ordered_keys = set(key_order)
        if any(key not in ordered_keys for key in obj):
            raise ValueError('key_order does not contain all keys in supplied dictionary')
        # same key set but more entries means a key is listed twice, which
        # would emit a repeated rule in the Association
        if len(key_order) != len(obj):
            raise ValueError('key_order contains a repeated key')
    # adjust object type: unordered / immutable sequences become Lists
    if isinstance(obj, (tuple, set, frozenset)):
        obj = list(obj)
    # format based on type (may recurse back to get_mma); bool is tested
    # before int because bool is a subclass of int
    if isinstance(obj, bool):
        return _get_mma_bool(obj)
    if isinstance(obj, str):
        return _get_mma_string(obj,
                               keep_symbols=keep_symbols)
    if isinstance(obj, (int, float)) or type(obj).__name__ == 'long':
        return _get_mma_real(obj, precision=precision,
                             keep_ints=keep_ints)
    if isinstance(obj, complex):
        return _get_mma_complex(obj, precision=precision,
                                keep_ints=keep_ints)
    if isinstance(obj, list):
        return _get_mma_array(obj, precision=precision,
                              keep_ints=keep_ints,
                              keep_symbols=keep_symbols,
                              single_line=True)
    if isinstance(obj, dict):
        return _get_mma_dict(obj, key_order=key_order,
                             precision=precision,
                             keep_ints=keep_ints,
                             keep_symbols=keep_symbols,
                             single_line=single_line)
    raise TypeError("Encountered unsupported type '%s'" % type(obj).__name__)
def save_as_mma(
        obj, filename,
        key_order=None,  # not passed on
        keep_symbols=_DEFAULT_KEEP_SYMBOLS,
        precision=_DEFAULT_PRECISION,
        keep_ints=_DEFAULT_KEEP_INTS):
    '''
    Constructs and writes to file a nested Mathematica expression from a python
    structure consisting of strings, numbers (including complex and py2.X longs),
    lists, tuples and sets (converted to lists) and dictionaries.
    The expression is built by get_mma (multi-line top-level dictionary layout)
    and overwrites filename. Any missing parent directories are created first.

    filename:     destination path (str or path-like)
    key_order:    an explicit order for the keys in the passed dictionary
                  can only be given when the passed obj is a dictionary
    keep_symbols: whether python strings should be converted to Mathematica symbols,
                  else enquoted to become strings
    precision:    the number of decimal digits in scientific notation format of numbers
    keep_ints:    whether integers should be kept formatted as such, else converted to
                  scientific notation (at the precision supplied); applies to the real
                  and imaginary components of complex numbers too
    '''
    # generate mma string
    mma_str = get_mma(obj, key_order=key_order,
                      keep_symbols=keep_symbols,
                      precision=precision,
                      keep_ints=keep_ints)
    # Create the parent directory only when the path names one. dirname is ''
    # for a bare filename, including one containing a backslash on POSIX
    # (where '\' is not a separator), and os.makedirs('') would raise.
    directory = os.path.dirname(filename)
    if directory:
        os.makedirs(directory, exist_ok=True)
    with open(filename, 'w') as handle:
        handle.write(mma_str)
def unit_tests():
    '''
    Self-checks of get_mma; raises AssertionError on the first failure.
    Run automatically when this module is executed as a script.
    '''
    # bools
    assert get_mma(True) == 'True'
    # strings
    assert get_mma('a', keep_symbols=True) == 'a'
    assert get_mma('a', keep_symbols=False) == '"a"'
    # reals
    assert get_mma(0.12345, precision=3) == '1.235*10^-01'
    assert get_mma(50) == '50'
    assert get_mma(50, keep_ints=False, precision=2) == '5.00*10^+01'
    # complex
    assert get_mma(3-4j, keep_ints=True) == '3-4I'
    assert get_mma(3-4j, keep_ints=False, precision=2) == '3.00*10^+00-4.00*10^+00I'
    # arrays
    assert get_mma([1, 2, 3]) == '{1, 2, 3}'
    assert get_mma([1, .1], precision=3) == '{1, 1.000*10^-01}'
    assert get_mma([1, 2], precision=1, keep_ints=False) == '{1.0*10^+00, 2.0*10^+00}'
    assert get_mma([1j], keep_ints=True) == '{0+1I}'
    # tuples and sets
    assert get_mma((1, 2, 3)) == get_mma(set([1, 2, 3])) == get_mma([1, 2, 3])
    # dictionaries: each rule of a multi-line dictionary is prefixed by
    # _INDENT_CHAR (four spaces); nested dictionaries stay on one line
    assert (get_mma({'a': True, 'b': .5},
                    key_order=['b', 'a'],
                    keep_symbols=True,
                    precision=2)
            == '<|\n'
               '    b -> 5.00*10^-01,\n'
               '    a -> True\n'
               '|>')
    assert (get_mma({'a': True, 'b': {'c': 'd'}, 'e': [1, {'f': 4}]},
                    key_order=['a', 'b', 'e'],
                    keep_symbols=True,
                    keep_ints=True)
            == '<|\n'
               '    a -> True,\n'
               '    b -> <| c -> d |>,\n'
               '    e -> {1, <| f -> 4 |>}\n'
               '|>')
    # errors: a key_order that omits or invents keys must be rejected
    for bad_dict, bad_order in (({'a': 1, 'b': 2}, ['a']),
                                ({'a': 1}, ['a', 'b'])):
        try:
            get_mma(bad_dict, key_order=bad_order)
        except ValueError:
            continue
        raise AssertionError


if __name__ == '__main__':
    unit_tests()