forked from ivy-llc/ivy
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmultiversion_frontend_test.py
130 lines (103 loc) · 3.79 KB
/
multiversion_frontend_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from ivy_tests import config
import sys
import jsonpickle
import importlib
def available_frameworks():
    """Return the names of the frameworks usable in this environment.

    ``"numpy"`` is always reported first. ``"jax"``, ``"tensorflow"`` and
    ``"torch"`` follow, in that order, but only when importing the package
    does not raise ``ImportError``.

    Returns
    -------
    list of str
        A fresh list of framework names.
    """
    found = ["numpy"]
    for optional in ("jax", "tensorflow", "torch"):
        try:
            # Really import the package (not merely locate it), so installs
            # whose import fails with ImportError are left out as well.
            importlib.import_module(optional)
        except ImportError:
            continue
        found.append(optional)
    return found
def convtrue(argument):
    """Unwrap ``argument`` if it is a ``NativeClass`` placeholder.

    Parameters
    ----------
    argument : object
        Any value taken from a function's positional or keyword arguments.

    Returns
    -------
    object
        The framework-specific class reference held by a ``NativeClass``
        instance, or ``argument`` itself when it is anything else.
    """
    return (
        argument._native_class if isinstance(argument, NativeClass) else argument
    )
class NativeClass:
    """
    Placeholder for a class that only exists in a specific framework.

    Attributes
    ----------
    _native_class : class reference
        A reference to the framework-specific class.
    """

    def __init__(self, native_class):
        """
        Construct the placeholder object.

        Parameters
        ----------
        native_class : class reference
            A reference to the framework-specific class being represented.
        """
        self._native_class = native_class

    def __repr__(self):
        # Show the wrapped class so a placeholder printed in debugging or
        # test-failure output says which framework class it stands for.
        return f"{type(self).__name__}({self._native_class!r})"
if __name__ == "__main__":
    # Protocol: the parent process passes "<framework>/<version>" specs on
    # the command line, then writes requests to stdin, three lines each:
    #   1. a jsonpickle-encoded dict {"a": positional args, "b": keyword args}
    #   2. the name of the frontend module to import
    #   3. the name of the function in that module to call
    # Each result is converted to numpy, jsonpickle-encoded and printed on
    # stdout.
    arg_lis = sys.argv
    if len(arg_lis) < 3:
        # argv[2] selects the ivy backend below; fail with a clear message
        # instead of an IndexError after ivy has already been imported.
        sys.exit(
            "usage: multiversion_frontend_test.py <fw>/<version> "
            "<backend_fw>/<version> [<fw>/<version> ...]"
        )
    fw_lis = []
    for spec in arg_lis[1:]:
        parts = spec.split("/")
        if parts[0] == "jax":
            # A jax spec bundles jaxlib ("jax/<ver>/jaxlib/<ver>"); register
            # the two packages as separate entries.
            fw_lis.append(parts[0] + "/" + parts[1])
            fw_lis.append(parts[2] + "/" + parts[3])
        else:
            fw_lis.append(spec)
    # NOTE(review): presumably this restricts which framework versions may be
    # imported, which is why it runs before `import ivy` below -- keep the
    # order.
    config.allow_global_framework_imports(fw=fw_lis)
    import ivy

    ivy.set_backend(arg_lis[2].split("/")[0])
    import numpy

    while True:
        try:
            request = input()
            frontend_module_name = input()
            func = input()
        except EOFError:
            # stdin was closed, so there are no more requests. Looping again
            # would spin forever: input() keeps raising EOFError once
            # end-of-file has been reached.
            break
        # SECURITY: jsonpickle can instantiate arbitrary classes while
        # decoding. This is only acceptable because the input presumably
        # comes from the trusted test harness that spawned this process.
        pickle_dict = jsonpickle.loads(request)
        frontend_fw = importlib.import_module(frontend_module_name)
        args_np, kwargs_np = pickle_dict["a"], pickle_dict["b"]
        # Positional args: numpy arrays -> native arrays of the current
        # backend, ivy dtypes -> native dtypes.
        args_frontend = ivy.nested_map(
            args_np,
            lambda x: ivy.native_array(x)
            if isinstance(x, numpy.ndarray)
            else ivy.as_native_dtype(x)
            if isinstance(x, ivy.Dtype)
            else x,
            shallow=False,
        )
        # Keyword args: only numpy arrays are converted here ...
        kwargs_frontend = ivy.nested_map(
            kwargs_np,
            lambda x: ivy.native_array(x) if isinstance(x, numpy.ndarray) else x,
            shallow=False,
        )
        # ... and dtype/device only for the top-level "dtype"/"device" keys.
        if "dtype" in kwargs_frontend:
            kwargs_frontend["dtype"] = ivy.as_native_dtype(kwargs_frontend["dtype"])
        if "device" in kwargs_frontend:
            kwargs_frontend["device"] = ivy.as_native_dev(kwargs_frontend["device"])
        # Replace NativeClass placeholders with the framework classes they
        # wrap (max_depth=10 bounds how deep nested_map descends).
        args_frontend = ivy.nested_map(
            args_frontend, fn=convtrue, include_derived=True, max_depth=10
        )
        kwargs_frontend = ivy.nested_map(
            kwargs_frontend, fn=convtrue, include_derived=True, max_depth=10
        )
        frontend_ret = frontend_fw.__dict__[func](*args_frontend, **kwargs_frontend)
        frontend_ret = ivy.to_numpy(frontend_ret)
        # When stdout is a pipe, Python block-buffers it; flush so the parent
        # receives each response as soon as it is produced.
        print(jsonpickle.dumps(frontend_ret), flush=True)