Skip to content

Commit

Permalink
Factor gen_test_model to separate module
Browse files Browse the repository at this point in the history
This patch factors the _gen_test_model function in policy_saver_test
into a separate module in a new testing subdirectory to facilitate reuse
across tests. This is intended to be used in the regalloc_trace_worker
test which has to create a test model to test that everything works with
TFLite.

Pull Request: google#414
  • Loading branch information
boomanaiden154 committed Jan 12, 2025
1 parent f179ce1 commit f8a1386
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 51 deletions.
54 changes: 3 additions & 51 deletions compiler_opt/rl/policy_saver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,55 +26,7 @@
from tf_agents.trajectories import time_step

from compiler_opt.rl import policy_saver


# copied from the llvm regalloc generator
def _gen_test_model(outdir: str):
policy_decision_label = 'index_to_evict'
policy_output_spec = """
[
{
"logging_name": "index_to_evict",
"tensor_spec": {
"name": "StatefulPartitionedCall",
"port": 0,
"type": "int64_t",
"shape": [
1
]
}
}
]
"""
per_register_feature_list = ['mask']
num_registers = 33

def get_input_signature():
"""Returns (time_step_spec, action_spec) for LLVM register allocation."""
inputs = dict(
(key, tf.TensorSpec(dtype=tf.int64, shape=(num_registers), name=key))
for key in per_register_feature_list)
return inputs

module = tf.Module()
# We have to set this useless variable in order for the TF C API to correctly
# intake it
module.var = tf.Variable(0, dtype=tf.int64)

def action(*inputs):
result = tf.math.argmax(
tf.cast(inputs[0]['mask'], tf.int32), axis=-1) + module.var
return {policy_decision_label: result}

module.action = tf.function()(action)
action = {
'action': module.action.get_concrete_function(get_input_signature())
}
tf.saved_model.save(module, outdir, signatures=action)
output_spec_path = os.path.join(outdir, 'output_spec.json')
with tf.io.gfile.GFile(output_spec_path, 'w') as f:
print(f'Writing output spec to {output_spec_path}.')
f.write(policy_output_spec)
from compiler_opt.testing import model_test_utils


class PolicySaverTest(tf.test.TestCase):
Expand Down Expand Up @@ -135,7 +87,7 @@ def test_save_policy(self):
def test_tflite_conversion(self):
sm_dir = os.path.join(self.get_temp_dir(), 'saved_model')
tflite_dir = os.path.join(self.get_temp_dir(), 'tflite_model')
_gen_test_model(sm_dir)
model_test_utils.gen_test_model(sm_dir)
policy_saver.convert_mlgo_model(sm_dir, tflite_dir)
self.assertTrue(
tf.io.gfile.exists(
Expand All @@ -148,7 +100,7 @@ def test_policy_serialization(self):
sm_dir = os.path.join(self.get_temp_dir(), 'model')
orig_dir = os.path.join(self.get_temp_dir(), 'orig_model')
dest_dir = os.path.join(self.get_temp_dir(), 'dest_model')
_gen_test_model(sm_dir)
model_test_utils.gen_test_model(sm_dir)
policy_saver.convert_mlgo_model(sm_dir, orig_dir)

serialized_policy = policy_saver.Policy.from_filesystem(orig_dir)
Expand Down
67 changes: 67 additions & 0 deletions compiler_opt/testing/model_test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# coding=utf-8
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for running tests that involve tensorflow model.s"""

import os

import tensorflow as tf

# copied from the llvm regalloc generator
def gen_test_model(outdir: str):
policy_decision_label = 'index_to_evict'
policy_output_spec = """
[
{
"logging_name": "index_to_evict",
"tensor_spec": {
"name": "StatefulPartitionedCall",
"port": 0,
"type": "int64_t",
"shape": [
1
]
}
}
]
"""
per_register_feature_list = ['mask']
num_registers = 33

def get_input_signature():
"""Returns (time_step_spec, action_spec) for LLVM register allocation."""
inputs = dict(
(key, tf.TensorSpec(dtype=tf.int64, shape=(num_registers), name=key))
for key in per_register_feature_list)
return inputs

module = tf.Module()
# We have to set this useless variable in order for the TF C API to correctly
# intake it
module.var = tf.Variable(0, dtype=tf.int64)

def action(*inputs):
result = tf.math.argmax(
tf.cast(inputs[0]['mask'], tf.int32), axis=-1) + module.var
return {policy_decision_label: result}

module.action = tf.function()(action)
action = {
'action': module.action.get_concrete_function(get_input_signature())
}
tf.saved_model.save(module, outdir, signatures=action)
output_spec_path = os.path.join(outdir, 'output_spec.json')
with tf.io.gfile.GFile(output_spec_path, 'w') as f:
print(f'Writing output spec to {output_spec_path}.')
f.write(policy_output_spec)

0 comments on commit f8a1386

Please sign in to comment.