Import tensorflow

This commit is contained in:
2026-02-15 21:45:42 -08:00
parent f3e8b90764
commit c530630153
20524 changed files with 9017694 additions and 25 deletions
@@ -0,0 +1,352 @@
"""Python wrappers around TensorFlow ops.
This file is MACHINE GENERATED! Do not edit.
"""
import collections
from tensorflow.python import pywrap_tfe as pywrap_tfe
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.security.fuzzing.py import annotation_types as _atypes
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.deprecation import deprecated_endpoints
from tensorflow.python.util import dispatch as _dispatch
from tensorflow.python.util.tf_export import tf_export
from typing import TypeVar, List, Any
from typing_extensions import Annotated
# Type variable used in the `Annotated[...]` hints of `xla_cluster_output` and
# its eager fallback. The constraints are given as strings (forward references
# to names in `_atypes`), one per accepted annotation type.
TV_XlaClusterOutput_T = TypeVar("TV_XlaClusterOutput_T", "_atypes.BFloat16", "_atypes.Bool", "_atypes.Complex128", "_atypes.Complex64", "_atypes.Float16", "_atypes.Float32", "_atypes.Float64", "_atypes.Float8e4m3b11fnuz", "_atypes.Float8e4m3fn", "_atypes.Float8e4m3fnuz", "_atypes.Float8e5m2", "_atypes.Float8e5m2fnuz", "_atypes.Half", "_atypes.Int16", "_atypes.Int2", "_atypes.Int32", "_atypes.Int4", "_atypes.Int64", "_atypes.Int8", "_atypes.QInt16", "_atypes.QInt32", "_atypes.QInt8", "_atypes.QUInt16", "_atypes.QUInt8", "_atypes.Resource", "_atypes.String", "_atypes.UInt16", "_atypes.UInt2", "_atypes.UInt32", "_atypes.UInt4", "_atypes.UInt64", "_atypes.UInt8", "_atypes.Variant")
@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_cluster_output')
def xla_cluster_output(input: Annotated[Any, TV_XlaClusterOutput_T], name=None) -> Annotated[Any, TV_XlaClusterOutput_T]:
  r"""Operator that connects the output of an XLA computation to other consumer graph nodes.

  In eager mode the op is first attempted through the fast-path executor; if
  that signals a fallback, the type-based dispatcher and then
  `xla_cluster_output_eager_fallback` are tried. Otherwise (or when the eager
  path raises a symbolic exception) an "XlaClusterOutput" node is added to the
  graph.

  Args:
    input: A `Tensor`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor`. Has the same type as `input`.

  Raises:
    TypeError, ValueError: Re-raised from the eager fallback or from graph
      construction when fallback dispatch reports `NOT_SUPPORTED`.
  """
  # Use the module-level context object when it is set, else obtain one.
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager attempt 1: fast-path execution.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaClusterOutput", name, input)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path declined; continue with the slower eager path below.
      pass
    # Eager attempt 2: type-based dispatch, then the explicit eager fallback.
    try:
      _result = _dispatcher_for_xla_cluster_output(
          (input, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_cluster_output_eager_fallback(
          input, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Give fallback dispatchers a chance before surfacing the error.
      _result = _dispatch.dispatch(
            xla_cluster_output, (), dict(input=input, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: type-based dispatch gets the first chance.
    _result = _dispatcher_for_xla_cluster_output(
        (input, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaClusterOutput", input=input, name=name)
  except (TypeError, ValueError):
    # Same fallback-dispatch-then-re-raise handling as on the eager path.
    _result = _dispatch.dispatch(
          xla_cluster_output, (), dict(input=input, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("T", _op._get_attr_type("T"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaClusterOutput", _inputs_flat, _attrs, _result)
  # Unpack the single output.
  _result, = _result
  return _result
# Export the wrapper, converted with `_ops.to_raw_op`, as
# `raw_ops.XlaClusterOutput`.
XlaClusterOutput = tf_export("raw_ops.XlaClusterOutput")(_ops.to_raw_op(xla_cluster_output))
# Shortcut to the wrapper's type-based dispatcher; `xla_cluster_output` calls it
# before falling back. NOTE(review): the `_tf_type_based_dispatcher` attribute
# is presumably attached by `add_type_based_api_dispatcher` — confirm.
_dispatcher_for_xla_cluster_output = xla_cluster_output._tf_type_based_dispatcher.Dispatch
def xla_cluster_output_eager_fallback(input: Annotated[Any, TV_XlaClusterOutput_T], name, ctx) -> Annotated[Any, TV_XlaClusterOutput_T]:
  """Execute "XlaClusterOutput" eagerly without the fast path.

  Args:
    input: The single op input; converted via `_execute.args_to_matching_eager`.
    name: Op name forwarded to `_execute.execute`.
    ctx: Context forwarded to the conversion and execution helpers.

  Returns:
    The single output of the executed op.
  """
  dtype_attr, converted = _execute.args_to_matching_eager([input], ctx, [])
  (tensor,) = converted
  flat_inputs = [tensor]
  op_attrs = ("T", dtype_attr)
  outputs = _execute.execute(
      b"XlaClusterOutput",
      1,
      inputs=flat_inputs,
      attrs=op_attrs,
      ctx=ctx,
      name=name,
  )
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaClusterOutput", flat_inputs, op_attrs, outputs)
  # Exactly one output was requested above; unpack it.
  (only_output,) = outputs
  return only_output
@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_launch')
def xla_launch(constants, args, resources: Annotated[List[Any], _atypes.Resource], Tresults, function, name=None):
  r"""XLA Launch Op. For use by the XLA JIT only.

  In eager mode the op is first attempted through the fast-path executor, then
  through the type-based dispatcher and `xla_launch_eager_fallback`. Otherwise
  (or when the eager path raises a symbolic exception) an "XlaLaunch" node is
  added to the graph.

  Args:
    constants: A list of `Tensor` objects.
    args: A list of `Tensor` objects.
    resources: A list of `Tensor` objects with type `resource`.
    Tresults: A list of `tf.DTypes`.
    function: A function decorated with @Defun.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tresults`. On the graph path, when the
    created op has no outputs, the op itself is returned instead.

  Raises:
    TypeError: On the graph path, if `resources` or `Tresults` is not a list
      or tuple. Also re-raised (like ValueError) when fallback dispatch
      reports `NOT_SUPPORTED`.
  """
  # Use the module-level context object when it is set, else obtain one.
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager attempt 1: fast-path execution. Attrs are passed as alternating
    # name/value arguments after the inputs.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaLaunch", name, constants, args, resources, "Tresults",
        Tresults, "function", function)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path declined; continue with the slower eager path below.
      pass
    # Eager attempt 2: type-based dispatch, then the explicit eager fallback.
    try:
      _result = _dispatcher_for_xla_launch(
          (constants, args, resources, Tresults, function, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_launch_eager_fallback(
          constants, args, resources, Tresults=Tresults, function=function,
          name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Give fallback dispatchers a chance before surfacing the error.
      _result = _dispatch.dispatch(
            xla_launch, (), dict(constants=constants, args=args,
                                 resources=resources, Tresults=Tresults,
                                 function=function, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: type-based dispatch gets the first chance.
    _result = _dispatcher_for_xla_launch(
        (constants, args, resources, Tresults, function, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  if not isinstance(resources, (list, tuple)):
    raise TypeError(
        "Expected list for 'resources' argument to "
        "'xla_launch' Op, not %r." % resources)
  # Computed here but not referenced again in this function.
  _attr_Nresources = len(resources)
  if not isinstance(Tresults, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tresults' argument to "
        "'xla_launch' Op, not %r." % Tresults)
  Tresults = [_execute.make_type(_t, "Tresults") for _t in Tresults]
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaLaunch", constants=constants, args=args, resources=resources,
                     Tresults=Tresults, function=function, name=name)
  except (TypeError, ValueError):
    # Same fallback-dispatch-then-re-raise handling as on the eager path; note
    # that `Tresults` has already been normalised by `make_type` at this point.
    _result = _dispatch.dispatch(
          xla_launch, (), dict(constants=constants, args=args,
                               resources=resources, Tresults=Tresults,
                               function=function, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  # With no outputs there is nothing to record; hand back the op itself.
  if not _result:
    return _op
  if _execute.must_record_gradient():
    _attrs = ("Tconstants", _op.get_attr("Tconstants"), "Targs",
              _op.get_attr("Targs"), "Nresources",
              _op._get_attr_int("Nresources"), "Tresults",
              _op.get_attr("Tresults"), "function", _op.get_attr("function"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaLaunch", _inputs_flat, _attrs, _result)
  return _result
# Export the wrapper, converted with `_ops.to_raw_op`, as `raw_ops.XlaLaunch`.
XlaLaunch = tf_export("raw_ops.XlaLaunch")(_ops.to_raw_op(xla_launch))
# Shortcut to the wrapper's type-based dispatcher; `xla_launch` calls it before
# falling back. NOTE(review): the `_tf_type_based_dispatcher` attribute is
# presumably attached by `add_type_based_api_dispatcher` — confirm.
_dispatcher_for_xla_launch = xla_launch._tf_type_based_dispatcher.Dispatch
def xla_launch_eager_fallback(constants, args, resources: Annotated[List[Any], _atypes.Resource], Tresults, function, name, ctx):
  """Execute "XlaLaunch" eagerly without the fast path.

  Args:
    constants: Inputs converted with `convert_to_mixed_eager_tensors`.
    args: Inputs converted with `convert_to_mixed_eager_tensors`.
    resources: List or tuple converted to tensors of dtype `resource`.
    Tresults: List or tuple of values accepted by `_execute.make_type`.
    function: Value passed through as the "function" attr.
    name: Op name forwarded to `_execute.execute`.
    ctx: Context forwarded to the conversion and execution helpers.

  Returns:
    The result of `_execute.execute`, requested with `len(Tresults)` outputs.

  Raises:
    TypeError: If `resources` or `Tresults` is not a list or tuple.
  """
  if not isinstance(resources, (list, tuple)):
    raise TypeError(
        "Expected list for 'resources' argument to "
        "'xla_launch' Op, not %r." % resources)
  num_resources = len(resources)
  if not isinstance(Tresults, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tresults' argument to "
        "'xla_launch' Op, not %r." % Tresults)
  result_dtypes = [_execute.make_type(dtype, "Tresults") for dtype in Tresults]

  # Convert the three input groups in the same order the op lists them.
  constant_dtypes, constant_tensors = _execute.convert_to_mixed_eager_tensors(
      constants, ctx)
  arg_dtypes, arg_tensors = _execute.convert_to_mixed_eager_tensors(args, ctx)
  resource_tensors = _ops.convert_n_to_tensor(resources, _dtypes.resource)

  flat_inputs = [*constant_tensors, *arg_tensors, *resource_tensors]
  op_attrs = (
      "Tconstants", constant_dtypes,
      "Targs", arg_dtypes,
      "Nresources", num_resources,
      "Tresults", result_dtypes,
      "function", function,
  )
  outputs = _execute.execute(
      b"XlaLaunch",
      len(result_dtypes),
      inputs=flat_inputs,
      attrs=op_attrs,
      ctx=ctx,
      name=name,
  )
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaLaunch", flat_inputs, op_attrs, outputs)
  return outputs
@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('xla_launch_v2')
def xla_launch_v2(args, Tresults, constants, resources, function, name=None):
  r"""XLA Launch Op. For use by the XLA JIT only.

  In eager mode the op is first attempted through the fast-path executor, then
  through the type-based dispatcher and `xla_launch_v2_eager_fallback`.
  Otherwise (or when the eager path raises a symbolic exception) an
  "XlaLaunchV2" node is added to the graph.

  Args:
    args: A list of `Tensor` objects.
    Tresults: A list of `tf.DTypes`.
    constants: A list of `ints`.
    resources: A list of `ints`.
    function: A function decorated with @Defun.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Tresults`.

  Raises:
    TypeError: On the graph path, if `Tresults`, `constants` or `resources` is
      not a list or tuple. Also re-raised (like ValueError) when fallback
      dispatch reports `NOT_SUPPORTED`.
  """
  # Use the module-level context object when it is set, else obtain one.
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    # Eager attempt 1: fast-path execution. Attrs are passed as alternating
    # name/value arguments after the inputs.
    try:
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "XlaLaunchV2", name, args, "Tresults", Tresults, "constants",
        constants, "resources", resources, "function", function)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path declined; continue with the slower eager path below.
      pass
    # Eager attempt 2: type-based dispatch, then the explicit eager fallback.
    try:
      _result = _dispatcher_for_xla_launch_v2(
          (args, Tresults, constants, resources, function, name,), None)
      if _result is not NotImplemented:
        return _result
      return xla_launch_v2_eager_fallback(
          args, Tresults=Tresults, constants=constants, resources=resources,
          function=function, name=name, ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Give fallback dispatchers a chance before surfacing the error.
      _result = _dispatch.dispatch(
            xla_launch_v2, (), dict(args=args, Tresults=Tresults,
                                    constants=constants, resources=resources,
                                    function=function, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    # Graph mode: type-based dispatch gets the first chance.
    _result = _dispatcher_for_xla_launch_v2(
        (args, Tresults, constants, resources, function, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  # Validate and normalise the three list-valued attrs before building the op.
  if not isinstance(Tresults, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tresults' argument to "
        "'xla_launch_v2' Op, not %r." % Tresults)
  Tresults = [_execute.make_type(_t, "Tresults") for _t in Tresults]
  if not isinstance(constants, (list, tuple)):
    raise TypeError(
        "Expected list for 'constants' argument to "
        "'xla_launch_v2' Op, not %r." % constants)
  constants = [_execute.make_int(_i, "constants") for _i in constants]
  if not isinstance(resources, (list, tuple)):
    raise TypeError(
        "Expected list for 'resources' argument to "
        "'xla_launch_v2' Op, not %r." % resources)
  resources = [_execute.make_int(_i, "resources") for _i in resources]
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "XlaLaunchV2", args=args, Tresults=Tresults, constants=constants,
                       resources=resources, function=function, name=name)
  except (TypeError, ValueError):
    # Same fallback-dispatch-then-re-raise handling as on the eager path; the
    # attrs passed here are the normalised lists built above.
    _result = _dispatch.dispatch(
          xla_launch_v2, (), dict(args=args, Tresults=Tresults,
                                  constants=constants, resources=resources,
                                  function=function, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  _result = _outputs[:]
  if _execute.must_record_gradient():
    _attrs = ("Targs", _op.get_attr("Targs"), "Tresults",
              _op.get_attr("Tresults"), "constants",
              _op.get_attr("constants"), "resources",
              _op.get_attr("resources"), "function", _op.get_attr("function"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "XlaLaunchV2", _inputs_flat, _attrs, _result)
  return _result
# Export the wrapper, converted with `_ops.to_raw_op`, as `raw_ops.XlaLaunchV2`.
XlaLaunchV2 = tf_export("raw_ops.XlaLaunchV2")(_ops.to_raw_op(xla_launch_v2))
# Shortcut to the wrapper's type-based dispatcher; `xla_launch_v2` calls it
# before falling back. NOTE(review): the `_tf_type_based_dispatcher` attribute
# is presumably attached by `add_type_based_api_dispatcher` — confirm.
_dispatcher_for_xla_launch_v2 = xla_launch_v2._tf_type_based_dispatcher.Dispatch
def xla_launch_v2_eager_fallback(args, Tresults, constants, resources, function, name, ctx):
  """Execute "XlaLaunchV2" eagerly without the fast path.

  Args:
    args: Inputs converted with `convert_to_mixed_eager_tensors`.
    Tresults: List or tuple of values accepted by `_execute.make_type`.
    constants: List or tuple of values accepted by `_execute.make_int`.
    resources: List or tuple of values accepted by `_execute.make_int`.
    function: Value passed through as the "function" attr.
    name: Op name forwarded to `_execute.execute`.
    ctx: Context forwarded to the conversion and execution helpers.

  Returns:
    The result of `_execute.execute`, requested with `len(Tresults)` outputs.

  Raises:
    TypeError: If `Tresults`, `constants` or `resources` is not a list or
      tuple.
  """
  if not isinstance(Tresults, (list, tuple)):
    raise TypeError(
        "Expected list for 'Tresults' argument to "
        "'xla_launch_v2' Op, not %r." % Tresults)
  result_dtypes = [_execute.make_type(dtype, "Tresults") for dtype in Tresults]

  if not isinstance(constants, (list, tuple)):
    raise TypeError(
        "Expected list for 'constants' argument to "
        "'xla_launch_v2' Op, not %r." % constants)
  constant_ints = [_execute.make_int(value, "constants") for value in constants]

  if not isinstance(resources, (list, tuple)):
    raise TypeError(
        "Expected list for 'resources' argument to "
        "'xla_launch_v2' Op, not %r." % resources)
  resource_ints = [_execute.make_int(value, "resources") for value in resources]

  arg_dtypes, arg_tensors = _execute.convert_to_mixed_eager_tensors(args, ctx)
  flat_inputs = list(arg_tensors)
  op_attrs = (
      "Targs", arg_dtypes,
      "Tresults", result_dtypes,
      "constants", constant_ints,
      "resources", resource_ints,
      "function", function,
  )
  outputs = _execute.execute(
      b"XlaLaunchV2",
      len(result_dtypes),
      inputs=flat_inputs,
      attrs=op_attrs,
      ctx=ctx,
      name=name,
  )
  if _execute.must_record_gradient():
    _execute.record_gradient("XlaLaunchV2", flat_inputs, op_attrs, outputs)
  return outputs
@@ -0,0 +1,25 @@
"""Gradients for XLA ops."""
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from tensorflow.python.framework import ops
@ops.RegisterGradient("XlaClusterOutput")
def _XlaClusterOutputGrad(_, grad):
  """Reject gradient computation through an "XlaClusterOutput" op.

  Registered as the gradient function for "XlaClusterOutput". It never
  produces a gradient; it always raises.

  Args:
    _: Unused.
    grad: Unused.

  Raises:
    RuntimeError: Always.
  """
  del grad  # unused
  # The adjacent literals are concatenated verbatim, so each fragment carries
  # its own trailing space (the original read "degradation.Please").
  raise RuntimeError("Gradient computation of graph in xla.compile() is "
                     "prohibited because it can cause performance degradation. "
                     "Please move gradient computation inside xla.compile().")
File diff suppressed because one or more lines are too long
@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: tensorflow/compiler/mlir/lite/debug/debug_options.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# Runtime/gencode compatibility check. The arguments are the PUBLIC domain, the
# version 5.28.3 (the same version named in the file header), an empty string,
# and the source .proto path.
# NOTE(review): the empty string is presumably the version-suffix slot, and the
# failure behaviour is defined by the protobuf runtime — confirm in its docs.
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    5,
    28,
    3,
    '',
    'tensorflow/compiler/mlir/lite/debug/debug_options.proto'
)
# @@protoc_insertion_point(imports)

# Default symbol database handle; not referenced again in the visible part of
# this module.
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n7tensorflow/compiler/mlir/lite/debug/debug_options.proto\x12\x14tensorflow.converter\"\x84\x02\n\x0c\x44\x65\x62ugOptions\x12\x15\n\x0bir_dump_dir\x18\x01 \x01(\t:\x00\x12\x1e\n\x12ir_dump_pass_regex\x18\x02 \x01(\t:\x02.*\x12\x1e\n\x12ir_dump_func_regex\x18\x03 \x01(\t:\x02.*\x12\x1c\n\renable_timing\x18\x04 \x01(\x08:\x05\x66\x61lse\x12\x19\n\x0fprint_ir_before\x18\x05 \x01(\t:\x00\x12\x18\n\x0eprint_ir_after\x18\x06 \x01(\t:\x00\x12#\n\x15print_ir_module_scope\x18\x07 \x01(\x08:\x04true\x12%\n\x1d\x65lide_elementsattrs_if_larger\x18\x08 \x01(\x03')
# Hand this module's globals() dict to the builder. The `_globals[...]` lookup
# below relies on these calls having added the `_DEBUGOPTIONS` entry.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.compiler.mlir.lite.debug.debug_options_pb2', _globals)
# Only when `_descriptor._USE_C_DESCRIPTORS` is false: clear the loaded options
# and record the descriptor's start/end offsets (presumably byte positions in
# the serialized file descriptor — confirm).
if not _descriptor._USE_C_DESCRIPTORS:
  DESCRIPTOR._loaded_options = None
  _globals['_DEBUGOPTIONS']._serialized_start=82
  _globals['_DEBUGOPTIONS']._serialized_end=342
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: tensorflow/compiler/mlir/lite/metrics/converter_error_data.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# Runtime/gencode compatibility check. The arguments are the PUBLIC domain, the
# version 5.28.3 (the same version named in the file header), an empty string,
# and the source .proto path.
# NOTE(review): the empty string is presumably the version-suffix slot, and the
# failure behaviour is defined by the protobuf runtime — confirm in its docs.
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    5,
    28,
    3,
    '',
    'tensorflow/compiler/mlir/lite/metrics/converter_error_data.proto'
)
# @@protoc_insertion_point(imports)

# Default symbol database handle; not referenced again in the visible part of
# this module.
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n@tensorflow/compiler/mlir/lite/metrics/converter_error_data.proto\x12\x0etflite.metrics\"\xdc\x06\n\x12\x43onverterErrorData\x12\x11\n\tcomponent\x18\x01 \x01(\t\x12\x14\n\x0csubcomponent\x18\x02 \x01(\t\x12@\n\nerror_code\x18\x03 \x01(\x0e\x32,.tflite.metrics.ConverterErrorData.ErrorCode\x12\x15\n\rerror_message\x18\x04 \x01(\t\x12=\n\x08operator\x18\x05 \x01(\x0b\x32+.tflite.metrics.ConverterErrorData.Operator\x12=\n\x08location\x18\x06 \x01(\x0b\x32+.tflite.metrics.ConverterErrorData.Location\x1a\x18\n\x08Operator\x12\x0c\n\x04name\x18\x01 \x01(\t\x1a\x39\n\x07\x46ileLoc\x12\x10\n\x08\x66ilename\x18\x01 \x01(\t\x12\x0c\n\x04line\x18\x02 \x01(\r\x12\x0e\n\x06\x63olumn\x18\x03 \x01(\r\x1aU\n\tSourceLoc\x12\x0c\n\x04name\x18\x01 \x01(\t\x12:\n\x06source\x18\x02 \x01(\x0b\x32*.tflite.metrics.ConverterErrorData.FileLoc\x1a\x85\x01\n\x08Location\x12=\n\x04type\x18\x01 \x01(\x0e\x32/.tflite.metrics.ConverterErrorData.LocationType\x12:\n\x04\x63\x61ll\x18\x02 \x03(\x0b\x32,.tflite.metrics.ConverterErrorData.SourceLoc\"\xc5\x01\n\tErrorCode\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x18\n\x14\x45RROR_NEEDS_FLEX_OPS\x10\x01\x12\x1a\n\x16\x45RROR_NEEDS_CUSTOM_OPS\x10\x02\x12%\n!ERROR_UNSUPPORTED_CONTROL_FLOW_V1\x10\x03\x12/\n+ERROR_STATEFUL_PARTITIONED_CALL_IN_FINAL_IR\x10\x04\x12\x1d\n\x18\x45RROR_GPU_NOT_COMPATIBLE\x10\xc8\x01\"J\n\x0cLocationType\x12\x0e\n\nUNKNOWNLOC\x10\x00\x12\x0b\n\x07NAMELOC\x10\x01\x12\x0f\n\x0b\x43\x41LLSITELOC\x10\x02\x12\x0c\n\x08\x46USEDLOC\x10\x03')
# Hand this module's globals() dict to the builder. The `_globals[...]` lookups
# below rely on these calls having added the `_CONVERTERERRORDATA*` entries.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.compiler.mlir.lite.metrics.converter_error_data_pb2', _globals)
# Only when `_descriptor._USE_C_DESCRIPTORS` is false: clear the loaded options
# and record each descriptor's start/end offsets (presumably byte positions in
# the serialized file descriptor — confirm). The nested entries fall inside
# the 85..945 range of the top-level `_CONVERTERERRORDATA` entry.
if not _descriptor._USE_C_DESCRIPTORS:
  DESCRIPTOR._loaded_options = None
  _globals['_CONVERTERERRORDATA']._serialized_start=85
  _globals['_CONVERTERERRORDATA']._serialized_end=945
  _globals['_CONVERTERERRORDATA_OPERATOR']._serialized_start=363
  _globals['_CONVERTERERRORDATA_OPERATOR']._serialized_end=387
  _globals['_CONVERTERERRORDATA_FILELOC']._serialized_start=389
  _globals['_CONVERTERERRORDATA_FILELOC']._serialized_end=446
  _globals['_CONVERTERERRORDATA_SOURCELOC']._serialized_start=448
  _globals['_CONVERTERERRORDATA_SOURCELOC']._serialized_end=533
  _globals['_CONVERTERERRORDATA_LOCATION']._serialized_start=536
  _globals['_CONVERTERERRORDATA_LOCATION']._serialized_end=669
  _globals['_CONVERTERERRORDATA_ERRORCODE']._serialized_start=672
  _globals['_CONVERTERERRORDATA_ERRORCODE']._serialized_end=869
  _globals['_CONVERTERERRORDATA_LOCATIONTYPE']._serialized_start=871
  _globals['_CONVERTERERRORDATA_LOCATIONTYPE']._serialized_end=945
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,51 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: tensorflow/compiler/mlir/lite/model_flags.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# Runtime/gencode compatibility check. The arguments are the PUBLIC domain, the
# version 5.28.3 (the same version named in the file header), an empty string,
# and the source .proto path.
# NOTE(review): the empty string is presumably the version-suffix slot, and the
# failure behaviour is defined by the protobuf runtime — confirm in its docs.
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    5,
    28,
    3,
    '',
    'tensorflow/compiler/mlir/lite/model_flags.proto'
)
# @@protoc_insertion_point(imports)

# Default symbol database handle; not referenced again in the visible part of
# this module.
_sym_db = _symbol_database.Default()
from tensorflow.compiler.mlir.lite import types_pb2 as tensorflow_dot_compiler_dot_mlir_dot_lite_dot_types__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n/tensorflow/compiler/mlir/lite/model_flags.proto\x12\x06tflite\x1a)tensorflow/compiler/mlir/lite/types.proto\"5\n\x0fInputArrayShape\x12\x0c\n\x04\x64ims\x18\x02 \x03(\x05\x12\x14\n\x0cunknown_rank\x18\x03 \x01(\x08\"\x93\x01\n\nInputArray\x12\x0c\n\x04name\x18\x01 \x01(\t\x12&\n\x05shape\x18\x06 \x01(\x0b\x32\x17.tflite.InputArrayShape\x12\x12\n\nmean_value\x18\x03 \x01(\x02\x12\x14\n\tstd_value\x18\x04 \x01(\x02:\x01\x31\x12%\n\tdata_type\x18\x05 \x01(\x0e\x32\x12.tflite.IODataType\"t\n\x08RnnState\x12\x13\n\x0bstate_array\x18\x01 \x01(\t\x12\x1e\n\x16\x62\x61\x63k_edge_source_array\x18\x02 \x01(\t\x12\x13\n\x0b\x64iscardable\x18\x05 \x01(\x08\x12\x0c\n\x04size\x18\x03 \x01(\x05\x12\x10\n\x08num_dims\x18\x04 \x01(\x05\"\xf5\x01\n\x0f\x41rraysExtraInfo\x12.\n\x07\x65ntries\x18\x01 \x03(\x0b\x32\x1d.tflite.ArraysExtraInfo.Entry\x1a\xb1\x01\n\x05\x45ntry\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x13\n\x0bname_regexp\x18\x07 \x01(\t\x12\x0b\n\x03min\x18\x02 \x01(\x01\x12\x0b\n\x03max\x18\x03 \x01(\x01\x12%\n\tdata_type\x18\x04 \x01(\x0e\x32\x12.tflite.IODataType\x12&\n\x05shape\x18\x05 \x01(\x0b\x32\x17.tflite.InputArrayShape\x12\x1c\n\x14\x63onstant_float_value\x18\x06 \x01(\x02\"\xd0\x05\n\nModelFlags\x12(\n\x0cinput_arrays\x18\x01 \x03(\x0b\x32\x12.tflite.InputArray\x12\x15\n\routput_arrays\x18\x02 \x03(\t\x12\x1d\n\x15\x63ontrol_output_arrays\x18\x18 \x03(\t\x12\x16\n\x0evariable_batch\x18\n \x01(\x08\x12$\n\nrnn_states\x18\x0c \x03(\x0b\x32\x10.tflite.RnnState\x12\x33\n\x0cmodel_checks\x18\x0e \x03(\x0b\x32\x1d.tflite.ModelFlags.ModelCheck\x12 \n\x18\x61llow_nonexistent_arrays\x18\x10 \x01(\x08\x12\x1d\n\x15\x61llow_nonascii_arrays\x18\x11 \x01(\x08\x12\x32\n\x11\x61rrays_extra_info\x18\x12 \x01(\x0b\x32\x17.tflite.ArraysExtraInfo\x12(\n\x1a\x63hange_concat_input_ranges\x18\x13 \x01(\x08:\x04true\x12\x17\n\x0fsaved_model_dir\x18\x14 \x01(\t\x12\x1b\n\x13saved_model_version\x18\x15 
\x01(\x05\x12\x18\n\x10saved_model_tags\x18\x16 \x03(\t\x12\"\n\x1asaved_model_exported_names\x18\x17 \x03(\t\x12\x16\n\x0euse_hlo_import\x18\x19 \x01(\x08\x12\x35\n\rhlo_file_type\x18\x1a \x01(\x0e\x32\x1e.tflite.ModelFlags.HloFileType\x1aT\n\nModelCheck\x12\x18\n\ncount_type\x18\x01 \x01(\t:\x04None\x12\x15\n\tcount_min\x18\x02 \x01(\x05:\x02-1\x12\x15\n\tcount_max\x18\x03 \x01(\x05:\x02-1\"7\n\x0bHloFileType\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x0c\n\x08HLO_TEXT\x10\x01\x12\r\n\tHLO_PROTO\x10\x02')
# Hand this module's globals() dict to the builder. The `_globals[...]` lookups
# below rely on these calls having added the descriptor entries.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.compiler.mlir.lite.model_flags_pb2', _globals)
# Only when `_descriptor._USE_C_DESCRIPTORS` is false: clear the loaded options
# and record each descriptor's start/end offsets (presumably byte positions in
# the serialized file descriptor — confirm). Nested entries
# (`_ARRAYSEXTRAINFO_ENTRY`, `_MODELFLAGS_*`) fall inside their parent's range.
if not _descriptor._USE_C_DESCRIPTORS:
  DESCRIPTOR._loaded_options = None
  _globals['_INPUTARRAYSHAPE']._serialized_start=102
  _globals['_INPUTARRAYSHAPE']._serialized_end=155
  _globals['_INPUTARRAY']._serialized_start=158
  _globals['_INPUTARRAY']._serialized_end=305
  _globals['_RNNSTATE']._serialized_start=307
  _globals['_RNNSTATE']._serialized_end=423
  _globals['_ARRAYSEXTRAINFO']._serialized_start=426
  _globals['_ARRAYSEXTRAINFO']._serialized_end=671
  _globals['_ARRAYSEXTRAINFO_ENTRY']._serialized_start=494
  _globals['_ARRAYSEXTRAINFO_ENTRY']._serialized_end=671
  _globals['_MODELFLAGS']._serialized_start=674
  _globals['_MODELFLAGS']._serialized_end=1394
  _globals['_MODELFLAGS_MODELCHECK']._serialized_start=1253
  _globals['_MODELFLAGS_MODELCHECK']._serialized_end=1337
  _globals['_MODELFLAGS_HLOFILETYPE']._serialized_start=1339
  _globals['_MODELFLAGS_HLOFILETYPE']._serialized_end=1394
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,21 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Stub signature only (body elided with `...`). The first three parameters are
# required; the remaining ones have defaults whose values the stub does not show.
def Convert(model_flags_proto_txt_raw: object, converter_flags_proto_txt_raw: object, input_contents_txt_raw: object, extended_return: bool = ..., debug_info_txt_raw: object = ..., quantization_py_function_library=...) -> object: ...
# Stub signature only (body elided with `...`). Only `input_contents_txt_raw` is
# required; every other parameter has a default whose value the stub does not show.
def ExperimentalMlirQuantizeModel(input_contents_txt_raw: object, disable_per_channel: bool = ..., fully_quantize: bool = ..., inference_type: int = ..., input_data_type: int = ..., output_data_type: int = ..., enable_numeric_verify: bool = ..., enable_whole_model_verify: bool = ..., op_blocklist: object = ..., node_blocklist: object = ..., enable_variable_quantization: bool = ..., disable_per_channel_for_dense_layers: bool = ..., debug_options_proto_txt_raw: object = ...) -> object: ...
# Stub signature only (body elided with `...`): one required argument, result
# typed as `object`.
def ExperimentalMlirSparsifyModel(input_contents_txt_raw: object) -> object: ...
# Stub signature only (body elided with `...`): takes a `str` and a `bool`
# (parameters are unnamed in the stub beyond `arg0`/`arg1`) and returns a `str`.
def FlatBufferToMlir(arg0: str, arg1: bool) -> str: ...
# Stub signature only (body elided with `...`): one required argument, result
# typed as `object`.
def RegisterCustomOpdefs(custom_opdefs_txt_raw: object) -> object: ...
# Stub signature only (body elided with `...`): no arguments, returns a `list`
# (element type not specified by the stub).
def RetrieveCollectedErrors() -> list: ...
@@ -0,0 +1,90 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Wraps TFLite Converter interface with python lazy loader."""
# We need to import pywrap_tensorflow prior to the converter wrapper.
# pylint: disable=invalid-import-order,g-bad-import-order
from tensorflow.python import pywrap_tensorflow # pylint: disable=unused-import
from tensorflow.compiler.mlir.lite.python import _pywrap_converter_api
from tensorflow.compiler.mlir.quantization.tensorflow.python import py_function_lib
def wrapped_convert(
    model_flags_str,
    converter_flags_str,
    input_data_str,
    debug_info_str,
):
  """Forward a conversion request to `_pywrap_converter_api.Convert`.

  The four arguments are passed through positionally and unchanged.
  `extended_return` is always `False`, and a new
  `py_function_lib.PyFunctionLibrary()` is supplied as the final argument.

  Returns:
    Whatever `_pywrap_converter_api.Convert` returns.
  """
  convert = _pywrap_converter_api.Convert
  extended_return = False
  function_library = py_function_lib.PyFunctionLibrary()
  return convert(
      model_flags_str,
      converter_flags_str,
      input_data_str,
      extended_return,
      debug_info_str,
      function_library,
  )
def wrapped_experimental_mlir_quantize(
    input_data_str,
    disable_per_channel,
    fully_quantize,
    inference_type,
    input_data_type,
    output_data_type,
    enable_numeric_verify,
    enable_whole_model_verify,
    denylisted_ops,
    denylisted_nodes,
    enable_variable_quantization,
    disable_per_channel_for_dense_layers,
    debug_options_str,
):
  """Forward to `_pywrap_converter_api.ExperimentalMlirQuantizeModel`.

  All thirteen arguments are passed positionally, in declaration order and
  unchanged.

  Returns:
    Whatever `ExperimentalMlirQuantizeModel` returns.
  """
  quantize_model = _pywrap_converter_api.ExperimentalMlirQuantizeModel
  # Same order as this function's parameters; forwarded positionally.
  positional_args = (
      input_data_str,
      disable_per_channel,
      fully_quantize,
      inference_type,
      input_data_type,
      output_data_type,
      enable_numeric_verify,
      enable_whole_model_verify,
      denylisted_ops,
      denylisted_nodes,
      enable_variable_quantization,
      disable_per_channel_for_dense_layers,
      debug_options_str,
  )
  return quantize_model(*positional_args)
def wrapped_experimental_mlir_sparsify(input_data_str):
  """Forward `input_data_str` to `_pywrap_converter_api.ExperimentalMlirSparsifyModel`."""
  sparsify_model = _pywrap_converter_api.ExperimentalMlirSparsifyModel
  return sparsify_model(input_data_str)
def wrapped_register_custom_opdefs(custom_opdefs_list):
  """Forward `custom_opdefs_list` to `_pywrap_converter_api.RegisterCustomOpdefs`."""
  register_opdefs = _pywrap_converter_api.RegisterCustomOpdefs
  return register_opdefs(custom_opdefs_list)
def wrapped_retrieve_collected_errors():
  """Return the result of `_pywrap_converter_api.RetrieveCollectedErrors()`."""
  retrieve_errors = _pywrap_converter_api.RetrieveCollectedErrors
  return retrieve_errors()
def wrapped_flat_buffer_file_to_mlir(model, input_is_filepath):
  """Forward both arguments positionally to `_pywrap_converter_api.FlatBufferToMlir`."""
  flat_buffer_to_mlir = _pywrap_converter_api.FlatBufferToMlir
  return flat_buffer_to_mlir(model, input_is_filepath)
@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: tensorflow/compiler/mlir/lite/types.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# Runtime/gencode compatibility check. The arguments are the PUBLIC domain, the
# version 5.28.3 (the same version named in the file header), an empty string,
# and the source .proto path.
# NOTE(review): the empty string is presumably the version-suffix slot, and the
# failure behaviour is defined by the protobuf runtime — confirm in its docs.
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    5,
    28,
    3,
    '',
    'tensorflow/compiler/mlir/lite/types.proto'
)
# @@protoc_insertion_point(imports)

# Default symbol database handle; not referenced again in the visible part of
# this module.
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n)tensorflow/compiler/mlir/lite/types.proto\x12\x06tflite*\xb3\x02\n\nIODataType\x12\x18\n\x14IO_DATA_TYPE_UNKNOWN\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\x13\n\x0fQUANTIZED_UINT8\x10\x02\x12\t\n\x05INT32\x10\x03\x12\t\n\x05INT64\x10\x04\x12\n\n\x06STRING\x10\x05\x12\x13\n\x0fQUANTIZED_INT16\x10\x06\x12\x08\n\x04\x42OOL\x10\x07\x12\r\n\tCOMPLEX64\x10\x08\x12\x12\n\x0eQUANTIZED_INT8\x10\t\x12\x0b\n\x07\x46LOAT16\x10\n\x12\x0b\n\x07\x46LOAT64\x10\x0b\x12\x0e\n\nCOMPLEX128\x10\x0c\x12\n\n\x06UINT64\x10\r\x12\x0c\n\x08RESOURCE\x10\x0e\x12\x0b\n\x07VARIANT\x10\x0f\x12\n\n\x06UINT32\x10\x10\x12\t\n\x05UINT8\x10\x11\x12\x08\n\x04INT8\x10\x12\x12\t\n\x05INT16\x10\x13\x12\n\n\x06UINT16\x10\x14')
# Hand this module's globals() dict to the builder. The `_globals[...]` lookup
# below relies on these calls having added the `_IODATATYPE` entry.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.compiler.mlir.lite.types_pb2', _globals)
# Only when `_descriptor._USE_C_DESCRIPTORS` is false: clear the loaded options
# and record the enum descriptor's start/end offsets (presumably byte
# positions in the serialized file descriptor — confirm).
if not _descriptor._USE_C_DESCRIPTORS:
  DESCRIPTOR._loaded_options = None
  _globals['_IODATATYPE']._serialized_start=54
  _globals['_IODATATYPE']._serialized_end=361
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
28,
3,
'',
'tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\nJtensorflow/compiler/mlir/quantization/stablehlo/quantization_options.proto\x12\x16stablehlo.quantization\"^\n\x13QuantizationOptions\x12G\n\x13quantization_method\x18\x01 \x01(\x0b\x32*.stablehlo.quantization.QuantizationMethod\"\xdb\x01\n\x12QuantizationMethod\x12V\n\x1apreset_quantization_method\x18\x01 \x01(\x0b\x32\x30.stablehlo.quantization.PresetQuantizationMethodH\x00\x12V\n\x1a\x63ustom_quantization_method\x18\x02 \x01(\x0b\x32\x30.stablehlo.quantization.CustomQuantizationMethodH\x00\x42\x15\n\x13quantization_method\"\x92\x02\n\x18PresetQuantizationMethod\x12T\n\rpreset_method\x18\x01 \x01(\x0e\x32=.stablehlo.quantization.PresetQuantizationMethod.PresetMethod\"\x9f\x01\n\x0cPresetMethod\x12\x16\n\x12METHOD_UNSPECIFIED\x10\x00\x12\x0f\n\x0bWEIGHT_ONLY\x10\x01\x12,\n(POST_TRAINING_QUANTIZATION_DYNAMIC_RANGE\x10\x02\x12\x0b\n\x07\x46LOAT16\x10\x03\x12+\n\'POST_TRAINING_QUANTIZATION_STATIC_RANGE\x10\x04\"r\n\x18\x43ustomQuantizationMethod\x12V\n\x1bquantization_component_spec\x18\x01 \x03(\x0b\x32\x31.stablehlo.quantization.QuantizationComponentSpec\"\xc5\x05\n\x19QuantizationComponentSpec\x12g\n\x16quantization_component\x18\x01 \x01(\x0e\x32G.stablehlo.quantization.QuantizationComponentSpec.QuantizationComponent\x12M\n\tbit_width\x18\x02 \x01(\x0e\x32:.stablehlo.quantization.QuantizationComponentSpec.BitWidth\x12K\n\x08\x62it_type\x18\x03 \x01(\x0e\x32\x39.stablehlo.quantization.QuantizationComponentSpec.BitType\x12\x1b\n\x13\x65nable_narrow_range\x18\x04 \x01(\x08\x12\'\n\x1f\x65nable_per_channel_quantization\x18\x05 \x01(\x08\x12\x18\n\x10\x65nable_symmetric\x18\x06 
\x01(\x08\"v\n\x15QuantizationComponent\x12\x19\n\x15\x43OMPONENT_UNSPECIFIED\x10\x00\x12\x18\n\x14\x43OMPONENT_ACTIVATION\x10\x01\x12\x14\n\x10\x43OMPONENT_WEIGHT\x10\x02\x12\x12\n\x0e\x43OMPONENT_BIAS\x10\x03\"k\n\x08\x42itWidth\x12\x19\n\x15\x42IT_WIDTH_UNSPECIFIED\x10\x00\x12\x0f\n\x0b\x42IT_WIDTH_4\x10\x01\x12\x0f\n\x0b\x42IT_WIDTH_8\x10\x02\x12\x10\n\x0c\x42IT_WIDTH_16\x10\x03\x12\x10\n\x0c\x42IT_WIDTH_32\x10\x04\"^\n\x07\x42itType\x12\x18\n\x14\x42IT_TYPE_UNSPECIFIED\x10\x00\x12\x10\n\x0c\x42IT_TYPE_INT\x10\x01\x12\x12\n\x0e\x42IT_TYPE_FLOAT\x10\x02\x12\x13\n\x0f\x42IT_TYPE_BFLOAT\x10\x03\x42\x03\xf8\x01\x01\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.compiler.mlir.quantization.stablehlo.quantization_options_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
_globals['DESCRIPTOR']._loaded_options = None
_globals['DESCRIPTOR']._serialized_options = b'\370\001\001'
_globals['_QUANTIZATIONOPTIONS']._serialized_start=102
_globals['_QUANTIZATIONOPTIONS']._serialized_end=196
_globals['_QUANTIZATIONMETHOD']._serialized_start=199
_globals['_QUANTIZATIONMETHOD']._serialized_end=418
_globals['_PRESETQUANTIZATIONMETHOD']._serialized_start=421
_globals['_PRESETQUANTIZATIONMETHOD']._serialized_end=695
_globals['_PRESETQUANTIZATIONMETHOD_PRESETMETHOD']._serialized_start=536
_globals['_PRESETQUANTIZATIONMETHOD_PRESETMETHOD']._serialized_end=695
_globals['_CUSTOMQUANTIZATIONMETHOD']._serialized_start=697
_globals['_CUSTOMQUANTIZATIONMETHOD']._serialized_end=811
_globals['_QUANTIZATIONCOMPONENTSPEC']._serialized_start=814
_globals['_QUANTIZATIONCOMPONENTSPEC']._serialized_end=1523
_globals['_QUANTIZATIONCOMPONENTSPEC_QUANTIZATIONCOMPONENT']._serialized_start=1200
_globals['_QUANTIZATIONCOMPONENTSPEC_QUANTIZATIONCOMPONENT']._serialized_end=1318
_globals['_QUANTIZATIONCOMPONENTSPEC_BITWIDTH']._serialized_start=1320
_globals['_QUANTIZATIONCOMPONENTSPEC_BITWIDTH']._serialized_end=1427
_globals['_QUANTIZATIONCOMPONENTSPEC_BITTYPE']._serialized_start=1429
_globals['_QUANTIZATIONCOMPONENTSPEC_BITTYPE']._serialized_end=1523
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,395 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines CalibrationAlgorithm for calculating min and max values calculated by calibration method."""
import abc
import itertools
import logging
import numpy as np
from tensorflow.compiler.mlir.quantization.stablehlo import quantization_config_pb2 as stablehlo_quant_config_pb2
from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_statistics_pb2 as calib_stats_pb2
# Short alias for the CalibrationMethod type nested in CalibrationOptions.
_CalibrationMethod = (
    stablehlo_quant_config_pb2.CalibrationOptions.CalibrationMethod
)
# Maps a calibration method to the class registered for it through the
# `_implements` decorator; consulted by `get_min_max_value`.
_REGISTRY = {}
def _implements(calib_method: _CalibrationMethod):
  """Returns a class decorator that registers a class under `calib_method`.

  Args:
    calib_method: Key under which the decorated class is stored in `_REGISTRY`.

  Returns:
    A decorator that records the class in `_REGISTRY` and returns the class
    unchanged.
  """

  def register(algorithm_cls):
    # Each calibration method may be claimed by at most one class.
    assert calib_method not in _REGISTRY
    _REGISTRY[calib_method] = algorithm_cls
    return algorithm_cls

  return register
class _CalibrationAlgorithmBase(abc.ABC):
  """Abstract base class for calibration algorithm.

  Holds the collected statistics and the calibration options; subclasses turn
  them into a (min, max) pair in `get_min_max_value`.
  """

  def __init__(
      self,
      statistics: calib_stats_pb2.CalibrationStatistics,
      calib_opts: stablehlo_quant_config_pb2.CalibrationOptions,
  ):
    """Stores the inputs shared by every calibration algorithm.

    Args:
      statistics: Collected calibration statistics.
      calib_opts: Calibration options used for calculating min and max.
    """
    self._calib_opts = calib_opts
    self._statistics = statistics

  @abc.abstractmethod
  def get_min_max_value(self) -> tuple[float, float]:
    """Returns the (min_value, max_value) pair; implemented by subclasses."""
class _HistogramCalibrationAlgorithmBase(_CalibrationAlgorithmBase):
  """Base class for histogram calibrators.

  Provides the histogram bin midpoints and the weighted mean-squared-error
  helpers that the histogram based algorithms share.
  """

  def __init__(
      self,
      statistics: calib_stats_pb2.CalibrationStatistics,
      calib_opts: stablehlo_quant_config_pb2.CalibrationOptions,
  ):
    """Builds histogram using statistics.histogram_statistics.

    lower_bound                                 hist_mid
         v                                         v
         |=========|=========|=========|=========|=========|
          bin width

    Args:
      statistics: Collected calibration statistics.
      calib_opts: Calibration options used for calculating min and max.
    """
    super().__init__(statistics, calib_opts)
    histogram = statistics.histogram_statistics
    self._bin_width = histogram.bin_width
    self._lower_bound = histogram.lower_bound
    self._hist_freq = np.array(histogram.hist_freq)
    self._num_bins = len(self._hist_freq)
    # Quantization error is always evaluated for 8-bit quantization.
    self._num_bits = 8

    # The i-th bin covers [lower_bound + i * bin_width,
    #                      lower_bound + (i + 1) * bin_width),
    # so its midpoint is (lower_bound + bin_width / 2) + i * bin_width.
    first_mid = self._lower_bound + self._bin_width / 2
    last_mid = first_mid + (self._num_bins - 1) * self._bin_width
    self._hist_mids = np.linspace(first_mid, last_mid, self._num_bins)

  def _get_dequantized_hist_mids_after_quantize(
      self, quant_min: float, quant_max: float
  ) -> np.ndarray:
    """Quantizes and dequantizes hist_mids using quant_min and quant_max.

    Quantization maps [quant_min, quant_max] onto [0, 2^num_bits - 1]; values
    outside that range are clipped to the nearest end. Pushing hist_mids
    through quantization and back yields the values a quantized model would
    actually see, so the gap between hist_mids and the result is the
    quantization error of the candidate range.

    Note that `quant_min == quant_max` makes the initial scale zero, which is
    divided by before any clamping is applied.

    Args:
      quant_min: The minimum real value that can be represented by a quantized
        value.
      quant_max: The maximum real value that can be represented by a quantized
        value.

    Returns:
      dequantized hist_mids after quantizing by quant_min and quant_max
    """
    lowest_level = 0
    highest_level = 2**self._num_bits - 1

    scale = (quant_max - quant_min) / highest_level
    zero_point = -quant_min / scale

    # Bound zero_point and scale so that an unusually small
    # (quant_max - quant_min) does not blow up the computation below.
    if abs(zero_point) > 9e9:
      zero_point = 9e9
    if abs(scale) < 1e-9:
      scale = 1e-9
    zero_point = round(zero_point)

    unclipped_levels = np.round(self._hist_mids / scale) + zero_point
    quantized_levels = np.clip(unclipped_levels, lowest_level, highest_level)
    return scale * (quantized_levels - zero_point)

  def _get_weighted_mean_squared_error(
      self, quant_min, quant_max
  ) -> tuple[float, float, float]:
    """Gets mean squared error between hist_mids and dequantized hist_mids.

    The squared difference of every bin midpoint is weighted by the frequency
    of that bin and summed:
      error = sum((hist_mids - dequantized_hist_mids)**2 * hist_freq)

    Args:
      quant_min: The minimum real value that can be represented by a quantized
        value.
      quant_max: The maximum real value that can be represented by a quantized
        value.

    Returns:
      (error, quant_min, quant_max): The weighted error followed by the range
      it was computed for, so that tuples can be compared directly.
    """
    residual = self._hist_mids - self._get_dequantized_hist_mids_after_quantize(
        quant_min, quant_max
    )
    weighted_error = np.sum(residual**2 * self._hist_freq)
    return (weighted_error, quant_min, quant_max)

  def _get_min_max_value_by_expanding_range(
      self, start_idx: int
  ) -> tuple[float, float]:
    """Starting from start_idx, expand left and right alternately to find the min value of mse loss.

    The range is widened by one bin per step until it spans the whole
    histogram, and the range with the smallest weighted error wins. If the
    range already spans the whole histogram at the start (a single bin), no
    candidate is evaluated and (inf, inf) is returned.

    Args:
      start_idx: Index to start quantization.

    Returns:
      (min_value, max_value): Min and max calculated.
    """
    last_idx = self._num_bins - 1
    # Best (mse_error, quant_min, quant_max) found so far.
    best = (float('inf'), float('inf'), float('inf'))
    left = right = start_idx
    # True when the next step should grow the range to the left.
    move_left = True

    while left != 0 or right != last_idx:
      # Grow leftwards on a left turn, or whenever the right edge is pinned.
      if (move_left and left > 0) or right == last_idx:
        left = max(left - 1, 0)
      else:
        right = min(right + 1, last_idx)
      # Toggle the direction for the next step.
      move_left = not move_left

      candidate = self._get_weighted_mean_squared_error(
          self._hist_mids[left], self._hist_mids[right]
      )
      best = min(candidate, best)

    # Drop the error and keep only (quant_min, quant_max).
    _, min_value, max_value = best
    return min_value, max_value
@_implements(_CalibrationMethod.CALIBRATION_METHOD_MIN_MAX)
class _MinMax(_CalibrationAlgorithmBase):
  """MinMaxCalibrationAlgorithm for calculating min and max values of calibration result.

  MinMax calibration reports the global extremes of the sample inputs:
    global min = min of given sample inputs
    global max = max of given sample inputs
  """

  def get_min_max_value(self) -> tuple[float, float]:
    """Reads the global min and max from the collected statistics.

    Returns:
      (min_value, max_value): Min and max calculated using MinMax
    """
    min_max = self._statistics.min_max_statistics
    return min_max.global_min, min_max.global_max
@_implements(_CalibrationMethod.CALIBRATION_METHOD_AVERAGE_MIN_MAX)
class _AverageMinMax(_CalibrationAlgorithmBase):
  """AverageMinMaxCalibrationAlgorithm for calculating min and max values of calibration result.

  AverageMinMax calibration averages the per-sample extremes:
    average of min = sum of min values / number of samples
    average of max = sum of max values / number of samples
  """

  def get_min_max_value(self) -> tuple[float, float]:
    """Calculates the average of min and max values.

    Returns:
      (min_value, max_value): Min and max calculated using AverageMinMax

    Raises:
      ValueError: num_samples is 0.
    """
    stats = self._statistics.average_min_max_statistics
    num_samples = stats.num_samples
    # Guard the divisions below; an empty sample set has no average.
    if num_samples == 0:
      raise ValueError(
          'num_samples must not be 0 when calibration method is'
          f' AverageMinMax: {self._calib_opts}'
      )

    return stats.min_sum / num_samples, stats.max_sum / num_samples
@_implements(_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_PERCENTILE)
class _HistogramPercentile(_HistogramCalibrationAlgorithmBase):
  """HistogramPercentile for calculating min and max values of calibration result."""

  def get_min_max_value(self) -> tuple[float, float]:
    """Calculates min and max from statistics using calibration options.

    A "percentile" is a statistical concept that represents the value below
    which a given percentage of data falls in a dataset. It involves sorting the
    data from smallest to largest and then finding the value at a specified
    percentage position. For example, the 0.01 percentile represents the value
    in a given data set that corresponds to the lowest 0.01% of the data.

    HistogramPercentile calibration uses min_percentile and max_percentile to
    find min and max.

    min_percentile and max_percentile must be in range [0, 100].
    min_percentile is 0.001 by default.
    max_percentile is 99.999 by default.

    Returns:
      (min_value, max_value): Min and max calculated using HistogramPercentile
    """
    total_freq = sum(self._hist_freq)
    # Normalized cumulative frequency; its values lie in [0, 1] by definition.
    hist_freq_cumsum = np.cumsum(self._hist_freq) / total_freq

    # min_percentile and max_percentile are converted from [0, 100] to [0, 1].
    calib_params = self._calib_opts.calibration_parameters
    min_quantile = calib_params.min_percentile / 100.0
    max_quantile = calib_params.max_percentile / 100.0

    # Get index of min/max quantile.
    min_quantile_idx = np.searchsorted(
        hist_freq_cumsum, min_quantile, side='right'
    )
    max_quantile_idx = np.searchsorted(
        hist_freq_cumsum, max_quantile, side='left'
    )

    # searchsorted returns len(hist_freq_cumsum) when the quantile lies past
    # the final cumulative value (e.g. min_percentile == 100 with
    # side='right'), which is not a valid bin index. Clamp to the last bin so
    # that every percentile in [0, 100] maps to an existing midpoint.
    last_idx = self._num_bins - 1
    min_quantile_idx = min(min_quantile_idx, last_idx)
    max_quantile_idx = min(max_quantile_idx, last_idx)

    # Get value of min/max quantile index.
    return self._hist_mids[min_quantile_idx], self._hist_mids[max_quantile_idx]
@_implements(_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE)
class _HistogramMseBruteforce(_HistogramCalibrationAlgorithmBase):
  """HistogramMseBruteforce for calculating min and max values of calibration result."""

  def get_min_max_value(self) -> tuple[float, float]:
    """Finds the optimal quant_min and quant_max by testing all possible cases.

    Every pair of distinct bin midpoints is tried as (quant_min, quant_max), so
    the result is optimal for the representative dataset, but not necessarily
    for the test dataset. With fewer than two bins there is no pair to try and
    (inf, inf) is returned.

    Returns:
      (min_value, max_value): Min and max calculated using
      HistogramMseBruteforce.
    """
    if self._num_bins > 512:
      logging.warning(
          'num_bins=%d is too large. The HISTOGRAM_MSE_BRUTEFORCE method tests'
          ' all histogram mid value pairs, so it may take a long time.',
          self._num_bins,
      )

    # Best (mse_error, quant_min, quant_max) seen so far.
    best = (float('inf'), float('inf'), float('inf'))

    # Evaluate the error of every (left, right) midpoint pair with left < right.
    bin_indices = range(self._num_bins)
    for left, right in itertools.combinations(bin_indices, 2):
      candidate = self._get_weighted_mean_squared_error(
          self._hist_mids[left], self._hist_mids[right]
      )
      best = min(candidate, best)

    _, min_value, max_value = best
    return min_value, max_value
@_implements(_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY)
class _HistogramMseMaxFrequency(_HistogramCalibrationAlgorithmBase):
  """HistogramMseMaxFrequency for calculating min and max values of calibration result."""

  def get_min_max_value(self) -> tuple[float, float]:
    """Finds min and max starting from the index of the max frequency.

    The search begins at the bin with the highest frequency and widens the
    range to both sides. This performs well when data is well spread on both
    sides of the max frequency.

    Returns:
      (min_value, max_value): Min and max calculated using method to expand the
      range based on max frequency.
    """
    # The most frequent bin is the starting point of the expansion.
    start_idx = np.argmax(self._hist_freq)
    return self._get_min_max_value_by_expanding_range(start_idx)
@_implements(_CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC)
class _HistogramMseSymmetric(_HistogramCalibrationAlgorithmBase):
  """HistogramMseSymmetric for calculating min and max values of calibration result."""

  def get_min_max_value(self) -> tuple[float, float]:
    """Finds min and max starting from the center index.

    The search begins at the center bin and widens the range to both sides.
    This works better when the data is well-centered.

    Returns:
      (min_value, max_value): Min and max calculated using the method starting
      from center and expanding.
    """
    # The middle bin is the starting point of the expansion.
    center_idx = self._num_bins // 2
    return self._get_min_max_value_by_expanding_range(center_idx)
def get_min_max_value(
    statistics: calib_stats_pb2.CalibrationStatistics,
    calib_opts: stablehlo_quant_config_pb2.CalibrationOptions,
) -> tuple[float, float]:
  """Calculates min and max from statistics using calibration options.

  Looks up the algorithm class registered for `calib_opts.calibration_method`,
  instantiates it with the given inputs and delegates to it.

  Args:
    statistics: Collected calibration statistics.
    calib_opts: Calibration options used for calculating min and max.

  Returns:
    (min_value, max_value): Min and max calculated using calib_opts.

  Raises:
    ValueError: Unsupported calibration method is given.
  """
  calib_method = calib_opts.calibration_method
  algorithm_cls = _REGISTRY.get(calib_method)
  # Registered values are always classes, so None means "not registered".
  if algorithm_cls is None:
    raise ValueError(f'Unsupported calibration method: {calib_method}')

  return algorithm_cls(statistics, calib_opts).get_min_max_value()
@@ -0,0 +1,49 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
28,
3,
'',
'tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\nXtensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.proto\x12\x15tensorflow.calibrator\"\x9c\x04\n\x15\x43\x61librationStatistics\x12Y\n\x12min_max_statistics\x18\x01 \x01(\x0b\x32=.tensorflow.calibrator.CalibrationStatistics.MinMaxStatistics\x12h\n\x1a\x61verage_min_max_statistics\x18\x02 \x01(\x0b\x32\x44.tensorflow.calibrator.CalibrationStatistics.AverageMinMaxStatistics\x12^\n\x14histogram_statistics\x18\x03 \x01(\x0b\x32@.tensorflow.calibrator.CalibrationStatistics.HistogramStatistics\x1a:\n\x10MinMaxStatistics\x12\x12\n\nglobal_min\x18\x01 \x01(\x02\x12\x12\n\nglobal_max\x18\x02 \x01(\x02\x1aP\n\x17\x41verageMinMaxStatistics\x12\x0f\n\x07min_sum\x18\x01 \x01(\x02\x12\x0f\n\x07max_sum\x18\x02 \x01(\x02\x12\x13\n\x0bnum_samples\x18\x03 \x01(\x05\x1aP\n\x13HistogramStatistics\x12\x11\n\tbin_width\x18\x01 \x01(\x02\x12\x13\n\x0blower_bound\x18\x02 \x01(\x02\x12\x11\n\thist_freq\x18\x03 \x03(\x02\"\xd0\x01\n\x18\x43\x61librationStatisticsMap\x12S\n\nstatistics\x18\x01 \x03(\x0b\x32?.tensorflow.calibrator.CalibrationStatisticsMap.StatisticsEntry\x1a_\n\x0fStatisticsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12;\n\x05value\x18\x02 \x01(\x0b\x32,.tensorflow.calibrator.CalibrationStatistics:\x02\x38\x01\x42\x03\xf8\x01\x01\x62\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.compiler.mlir.quantization.tensorflow.calibrator.calibration_statistics_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
_globals['DESCRIPTOR']._loaded_options = None
_globals['DESCRIPTOR']._serialized_options = b'\370\001\001'
_globals['_CALIBRATIONSTATISTICSMAP_STATISTICSENTRY']._loaded_options = None
_globals['_CALIBRATIONSTATISTICSMAP_STATISTICSENTRY']._serialized_options = b'8\001'
_globals['_CALIBRATIONSTATISTICS']._serialized_start=116
_globals['_CALIBRATIONSTATISTICS']._serialized_end=656
_globals['_CALIBRATIONSTATISTICS_MINMAXSTATISTICS']._serialized_start=434
_globals['_CALIBRATIONSTATISTICS_MINMAXSTATISTICS']._serialized_end=492
_globals['_CALIBRATIONSTATISTICS_AVERAGEMINMAXSTATISTICS']._serialized_start=494
_globals['_CALIBRATIONSTATISTICS_AVERAGEMINMAXSTATISTICS']._serialized_end=574
_globals['_CALIBRATIONSTATISTICS_HISTOGRAMSTATISTICS']._serialized_start=576
_globals['_CALIBRATIONSTATISTICS_HISTOGRAMSTATISTICS']._serialized_end=656
_globals['_CALIBRATIONSTATISTICSMAP']._serialized_start=659
_globals['_CALIBRATIONSTATISTICSMAP']._serialized_end=867
_globals['_CALIBRATIONSTATISTICSMAP_STATISTICSENTRY']._serialized_start=772
_globals['_CALIBRATIONSTATISTICSMAP_STATISTICSENTRY']._serialized_end=867
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: tensorflow/compiler/mlir/quantization/tensorflow/exported_model.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
28,
3,
'',
'tensorflow/compiler/mlir/quantization/tensorflow/exported_model.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from tensorflow.core.framework import graph_pb2 as tensorflow_dot_core_dot_framework_dot_graph__pb2
from tensorflow.core.protobuf import meta_graph_pb2 as tensorflow_dot_core_dot_protobuf_dot_meta__graph__pb2
from tensorflow.core.protobuf import saver_pb2 as tensorflow_dot_core_dot_protobuf_dot_saver__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\nEtensorflow/compiler/mlir/quantization/tensorflow/exported_model.proto\x12\x17tensorflow.quantization\x1a%tensorflow/core/framework/graph.proto\x1a)tensorflow/core/protobuf/meta_graph.proto\x1a$tensorflow/core/protobuf/saver.proto\"\xbe\x03\n\rExportedModel\x12\'\n\tgraph_def\x18\x01 \x01(\x0b\x32\x14.tensorflow.GraphDef\x12\x16\n\x0einit_node_name\x18\x02 \x01(\t\x12\x16\n\x0e\x63heckpoint_dir\x18\x05 \x01(\t\x12U\n\x10\x66unction_aliases\x18\x06 \x03(\x0b\x32;.tensorflow.quantization.ExportedModel.FunctionAliasesEntry\x12\x31\n\x0f\x61sset_file_defs\x18\x08 \x03(\x0b\x32\x18.tensorflow.AssetFileDef\x12\'\n\tsaver_def\x18\n \x01(\x0b\x32\x14.tensorflow.SaverDef\x1a\x36\n\x14\x46unctionAliasesEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01J\x04\x08\x03\x10\x04J\x04\x08\x04\x10\x05J\x04\x08\x07\x10\x08J\x04\x08\t\x10\nR\x15variable_shared_namesR\x11restore_node_nameR\x0esave_node_nameR\x17\x66ile_prefix_tensor_nameb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.compiler.mlir.quantization.tensorflow.exported_model_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_EXPORTEDMODEL_FUNCTIONALIASESENTRY']._loaded_options = None
_globals['_EXPORTEDMODEL_FUNCTIONALIASESENTRY']._serialized_options = b'8\001'
_globals['_EXPORTEDMODEL']._serialized_start=219
_globals['_EXPORTEDMODEL']._serialized_end=665
_globals['_EXPORTEDMODEL_FUNCTIONALIASESENTRY']._serialized_start=504
_globals['_EXPORTEDMODEL_FUNCTIONALIASESENTRY']._serialized_end=558
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,770 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines a wrapper class for overridden python method definitions."""
from collections.abc import Callable, Collection, Mapping, Sequence
import functools
import traceback
from typing import Optional, TypeVar
from absl import logging
from tensorflow.compiler.mlir.quantization.stablehlo import quantization_config_pb2 as stablehlo_quant_config_pb2
from tensorflow.compiler.mlir.quantization.tensorflow import exported_model_pb2
from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2
from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_algorithm
from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_statistics_pb2
from tensorflow.compiler.mlir.quantization.tensorflow.python import pywrap_function_lib
from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as rd
from tensorflow.compiler.mlir.quantization.tensorflow.python import save_model
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import wrap_function
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_conversion
from tensorflow.python.lib.io import file_io
from tensorflow.python.saved_model import load
from tensorflow.python.saved_model import loader_impl
from tensorflow.python.trackable import autotrackable
from tensorflow.python.types import core
# Name of the saved model assets directory.
_ASSETS_DIR = 'assets'
# Name of the additional assets directory that is copied alongside it.
_ASSETS_EXTRA_DIR = 'assets.extra'
# Type variable for a type that is not `None`. This represents a return value of
# methods in `PyFunctionLibrary` that should not be `None`, as `None` represents
# that the execution was unsuccessful, which is translated to `std::nullopt_t`
# on the C++ side.
NotNoneT = TypeVar('NotNoneT')
def _get_saver_def_or_none(
    exported_model: exported_model_pb2.ExportedModel,
) -> Optional[saver_pb2.SaverDef]:
  """Returns the SaverDef from ExportedModel, None otherwise.

  Args:
    exported_model: ExportedModel to take the SaverDef from.

  Returns:
    The `saver_def` field when it is set on `exported_model`, otherwise None.
  """
  has_saver_def = exported_model.HasField('saver_def')
  return exported_model.saver_def if has_saver_def else None
def _copy_assets(src_path: str, dst_path: str) -> None:
  """Copies the assets directory of the saved model.

  Clones the contents of the assets/ and assets.extra/ directories from the
  source saved model directory to the destination saved model directory,
  keeping the relative sub-directory layout. Nothing is copied for an assets
  directory that does not exist in the source directory.

  Args:
    src_path: Source saved model directory.
    dst_path: Destination saved model directory. This directory must exist.
  """
  for assets_dir_name in [_ASSETS_DIR, _ASSETS_EXTRA_DIR]:
    src_assets_path = file_io.join(src_path, assets_dir_name)
    if not file_io.file_exists_v2(src_assets_path):
      # Do nothing if the source assets path does not exist.
      continue

    dst_assets_path = file_io.join(dst_path, assets_dir_name)
    file_io.create_dir_v2(dst_assets_path)

    for curr_dir, _, files in file_io.walk_v2(src_assets_path):
      if not files:
        continue

      # Every walked directory starts with `src_assets_path`; swap only that
      # leading occurrence so that a nested directory whose path happens to
      # repeat the source path is not rewritten a second time.
      curr_dst_dir = curr_dir.replace(src_assets_path, dst_assets_path, 1)
      # Nested asset directories do not exist in the destination yet. This
      # call succeeds when the directory is already there.
      file_io.recursive_create_dir_v2(curr_dst_dir)

      for asset_file_name in files:
        src_asset_file = file_io.join(curr_dir, asset_file_name)
        dst_asset_file = file_io.join(curr_dst_dir, asset_file_name)

        file_io.copy_v2(src_asset_file, dst_asset_file)
        logging.info(
            'Copied asset file: %s -> %s', src_asset_file, dst_asset_file
        )
def _validate_representative_dataset(
    representative_dataset: rd.RepresentativeDatasetOrMapping,
    signature_keys: Collection[str],
) -> None:
  """Validates the representative dataset, based on the signature keys.

  A representative dataset is given either as a single
  `RepresentativeDataset` or as a map of signature key to the corresponding
  `RepresentativeDataset`. Which forms are acceptable depends on
  `signature_keys`:

  * A mapping is accepted only when its keys exactly match the elements of
    `signature_keys`.
  * A single `RepresentativeDataset` is accepted only when
    `len(signature_keys)` is not greater than 1.
  * This function also assumes `len(signature_keys) > 0`.

  Args:
    representative_dataset: A `RepresentativeDataset` or a map of string to
      `RepresentativeDataset` to be validated.
    signature_keys: A collection of strings that contains the signature keys,
      each identifying a `SignatureDef`.

  Raises:
    ValueError: Iff `representative_dataset` does not satisfy the conditions
      above.
  """
  if not isinstance(representative_dataset, Mapping):
    # A bare dataset cannot be attributed to one of several signatures.
    if len(signature_keys) > 1:
      raise ValueError(
          'Representative dataset is not a mapping '
          f'(got: {type(representative_dataset)}), '
          'but there is more than one signature key provided. '
          'Please provide a map of {signature_key -> dataset} '
          'with more than one signature key.'
      )
    return

  # A mapping must cover exactly the requested signatures.
  keys_match = set(signature_keys) == set(representative_dataset.keys())
  if not keys_match:
    raise ValueError(
        'The signature keys and the keys of representative dataset map '
        f'do not match. Signature keys: {set(signature_keys)}, '
        f'representative dataset map: {set(representative_dataset.keys())}.'
    )
def _replace_tensors_by_numpy_ndarrays(
    repr_ds_map: rd.RepresentativeDatasetMapping,
) -> None:
  """Replaces tf.Tensors by their evaluated numpy arrays.

  The mapping is updated in place: each dataset is swapped for the result of
  `rd.replace_tensors_by_numpy_ndarrays`, all evaluated in one shared session.

  This assumes that tf.Tensors in representative samples are created in the
  default Graph. It will raise an error if tensors are created in a different
  graph.

  Args:
    repr_ds_map: SignatureDef key -> RepresentativeDataset mapping.
  """
  with session.Session() as sess:
    for key in repr_ds_map:
      # Only values of existing keys are reassigned, so iterating the mapping
      # while updating it is safe.
      original_dataset = repr_ds_map[key]
      evaluated_dataset = rd.replace_tensors_by_numpy_ndarrays(
          original_dataset, sess
      )
      repr_ds_map[key] = evaluated_dataset
def _create_sample_validator(
expected_input_keys: Collection[str],
) -> Callable[[rd.RepresentativeSample], rd.RepresentativeSample]:
"""Creates a validator function for a representative sample.
Args:
expected_input_keys: Input keys (keyword argument names) that the function
the sample will be used for is expecting to receive.
Returns:
A callable that validates a `RepresentativeSample`.
"""
def validator(
sample: rd.RepresentativeSample,
) -> rd.RepresentativeSample:
"""Validates a single instance of representative sample.
This provides a simple check for `sample` that this is a mapping of
{input_key: input_value}.
Args:
sample: A `RepresentativeSample` to validate.
Returns:
`sample` iff it is valid.
Raises:
ValueError: iff the sample isn't an instance of `Mapping`.
KeyError: iff the sample does not have the set of input keys that match
the input keys of the function.
"""
if not isinstance(sample, Mapping):
raise ValueError(
'Invalid representative sample type. Provide a mapping '
'(usually a dict) of {input_key: input_value}. '
f'Got type: {type(sample)} instead.'
)
if set(sample.keys()) != expected_input_keys:
raise KeyError(
'Invalid input keys for representative sample. The function expects '
f'input keys of: {set(expected_input_keys)}. '
f'Got: {set(sample.keys())}. Please provide correct input keys for '
'representative samples.'
)
return sample
return validator
# TODO(b/249918070): Implement a progress bar.
def _log_sample_num_for_calibration(
    representative_dataset: rd.RepresentativeDataset,
) -> rd.RepresentativeDataset:
  """Yields the samples unchanged while logging calibration progress.

  The total dataset size (or the fact that it is unknown) is logged once at
  INFO level before iteration starts. While iterating, a
  "sample number / total num samples" message is emitted at DEBUG level for
  every 5 iterations, and a completion message is logged at INFO level once the
  dataset is exhausted.

  This is useful for tracking the calibration step, which is often slow and may
  look stale if no logs are being printed.

  Args:
    representative_dataset: The representative dataset.

  Yields:
    The representative samples from `representative_dataset` without any
    modification.
  """
  known_size: Optional[int] = rd.get_num_samples(representative_dataset)
  if known_size is not None:
    total_num_samples = str(known_size)
    logging.info('Using representative dataset of size: %s', total_num_samples)
  else:
    # The size cannot be determined up front; show a placeholder instead.
    total_num_samples = '?'
    logging.info('Representative dataset size unknown.')

  # Pre-initialized so that the completion log works for an empty dataset.
  sample_num = 0
  for sample_num, sample in enumerate(representative_dataset, start=1):
    # Log the sample number for every 5 iterations.
    logging.log_every_n(
        logging.DEBUG,
        'Running representative sample for calibration: %d / %s',
        5,
        sample_num,
        total_num_samples,
    )
    yield sample

  logging.info(
      'Running representative samples complete: %d / %s',
      sample_num,
      total_num_samples,
  )
def _run_function_for_calibration_graph_mode(
    sess: session.Session,
    signature_def: meta_graph_pb2.SignatureDef,
    representative_dataset: rd.RepresentativeDataset,
) -> None:
  """Feeds every representative sample through the function in a session.

  NOTE: This is intended to be run in graph mode (TF1).

  The function to run is identified by `signature_def`: its outputs are used as
  fetches and its inputs determine the keys every sample must provide.

  Args:
    sess: The Session object to run the function in.
    signature_def: A SignatureDef that identifies a function by specifying the
      inputs and outputs.
    representative_dataset: The representative dataset to run through the
      function.
  """
  fetches = [
      tensor_info.name for tensor_info in signature_def.outputs.values()
  ]
  validate_sample = _create_sample_validator(
      expected_input_keys=signature_def.inputs.keys()
  )

  for raw_sample in _log_sample_num_for_calibration(representative_dataset):
    sample = validate_sample(raw_sample)
    # Create a mapping from input tensor name to the input tensor value.
    # ex) "Placeholder:0" -> [0, 1, 2]
    feed_dict = rd.create_feed_dict_from_input_data(sample, signature_def)
    sess.run(fetches, feed_dict=feed_dict)
def _run_graph_for_calibration_graph_mode(
    model_dir: str,
    tags: Collection[str],
    representative_dataset_map: rd.RepresentativeDatasetMapping,
) -> None:
  """Runs the graph for calibration in graph mode.

  This function assumes _graph mode_ (used when legacy TF1 is used or when eager
  mode is explicitly disabled) when running the graph. This step is used in
  order to collect the statistics in CustomAggregatorOp for quantization using
  the representative dataset for the actual data provided for inference.

  Note that `representative_dataset_map` is updated in place: its datasets are
  replaced by versions where tf.Tensors are evaluated to numpy ndarrays.

  Args:
    model_dir: Path to SavedModel directory.
    tags: Collection of tags identifying the MetaGraphDef within the SavedModel.
    representative_dataset_map: A map where signature keys are mapped to
      corresponding representative datasets.

  Raises:
    ValueError: When running the function with the representative dataset fails.
  """
  # Replace tf.Tensors by numpy ndarrays in order to reuse the samples in a
  # different graph when running the calibration.
  _replace_tensors_by_numpy_ndarrays(representative_dataset_map)

  # Run the calibration in a new graph to avoid name collision, which could
  # happen when the same model is loaded multiple times in the default graph.
  with ops.Graph().as_default():
    with session.Session() as sess:
      meta_graph: meta_graph_pb2.MetaGraphDef = loader_impl.load(
          sess, tags, export_dir=model_dir
      )

      for signature_key, repr_ds in representative_dataset_map.items():
        sig_def = meta_graph.signature_def[signature_key]

        try:
          _run_function_for_calibration_graph_mode(
              sess, signature_def=sig_def, representative_dataset=repr_ds
          )
        except Exception as ex:
          raise ValueError(
              'Failed to run representative dataset through the '
              f'function with the signature key: {signature_key}.'
          ) from ex
def _convert_values_to_tf_tensors(
    sample: rd.RepresentativeSample,
) -> Mapping[str, core.Tensor]:
  """Returns a copy of `sample` whose values are all Tensors.

  Values that already are Tensors are kept as they are; every other
  (tensor-like) value is converted to a Tensor. `sample` itself is never
  mutated: a new dict is built and returned.

  Args:
    sample: A representative sample, which is a map of {name -> tensorlike
      value}.

  Returns:
    Converted map of {name -> tensor}.
  """

  def as_tensor(value):
    # Already-converted values are passed through untouched.
    if isinstance(value, core.Tensor):
      return value
    return tensor_conversion.convert_to_tensor_v2_with_dispatch(value)

  return {name: as_tensor(value) for name, value in sample.items()}
def _run_function_for_calibration_eager_mode(
    func: wrap_function.WrappedFunction,
    representative_dataset: rd.RepresentativeDataset,
) -> None:
  """Calls `func` once for every sample of the representative dataset.

  NOTE: This is intended to be run in eager mode (TF2).

  Args:
    func: The function to run the representative samples through.
    representative_dataset: Representative dataset used for calibration. The
      input keys and input values of the representative samples should match the
      keyword arguments of `func`.
  """
  # Only the keyword part of the structured signature is needed; it gives the
  # set of input keys every sample is validated against.
  keyword_args = func.structured_input_signature[1]
  validate_sample = _create_sample_validator(
      expected_input_keys=keyword_args.keys()
  )

  for raw_sample in _log_sample_num_for_calibration(representative_dataset):
    sample = validate_sample(raw_sample)
    # Convert any non-Tensor values from the sample to Tensors.
    # This conversion is required because the model saved in `model_dir` is
    # saved using TF1 SavedModelBuilder, which doesn't save the
    # SavedObjectGraph.
    func(**_convert_values_to_tf_tensors(sample))
def _run_graph_for_calibration_eager_mode(
    model_dir: str,
    tags: Collection[str],
    representative_dataset_map: rd.RepresentativeDatasetMapping,
) -> None:
  """Runs the graph for calibration in eager mode.

  This function assumes _eager mode_ (enabled in TF2 by default) when running
  the graph. This step is used in order to collect the statistics in
  CustomAggregatorOp for quantization using the representative dataset for the
  actual data provided for inference.

  Args:
    model_dir: Path to SavedModel directory.
    tags: Collection of tags identifying the MetaGraphDef within the SavedModel.
    representative_dataset_map: A map where signature keys are mapped to
      corresponding representative datasets.

  Raises:
    ValueError: When running the function with the representative dataset fails.
  """
  root: autotrackable.AutoTrackable = load.load(model_dir, tags)

  for signature_key, repr_ds in representative_dataset_map.items():
    try:
      # The signature lookup stays inside the `try` so that a failed lookup is
      # reported with the same error as a failed run.
      signature_func = root.signatures[signature_key]
      _run_function_for_calibration_eager_mode(
          func=signature_func, representative_dataset=repr_ds
      )
    except Exception as ex:
      raise ValueError(
          'Failed to run representative dataset through the '
          f'function with the signature key: {signature_key}.'
      ) from ex
def _run_graph_for_calibration(
    float_model_dir: str,
    signature_keys: Sequence[str],
    tags: Collection[str],
    representative_dataset: rd.RepresentativeDatasetOrMapping,
    force_graph_mode_calibration: bool,
) -> None:
  """Runs the graph for calibration using representative datasets.

  Args:
    float_model_dir: Path to the model to calibrate.
    signature_keys: Sequence of keys identifying SignatureDef containing inputs
      and outputs.
    tags: Collection of tags identifying the MetaGraphDef within the SavedModel
      to analyze.
    representative_dataset: An iterator that returns a dictionary of {input_key:
      input_value} or a mapping from signature keys to such iterators. When
      `signature_keys` contains more than one signature key,
      `representative_dataset` should be a mapping that maps each signature key
      to the corresponding representative dataset.
    force_graph_mode_calibration: If set to true, it forces calibration in graph
      mode instead of eager mode when the context is in eager mode.

  Raises:
    ValueError iff:
      * The representative dataset format is invalid.
      * It fails to run the functions using the representative datasets.
  """
  try:
    _validate_representative_dataset(representative_dataset, signature_keys)
  except Exception as ex:
    raise ValueError('Invalid representative dataset.') from ex

  # The functions below only handle mappings, so a bare dataset is wrapped into
  # a single-entry mapping first.
  if isinstance(representative_dataset, Mapping):
    representative_dataset_map = representative_dataset
  else:
    # `signature_keys` is guaranteed to have only one element after the
    # validation.
    representative_dataset_map = {signature_keys[0]: representative_dataset}

  try:
    if not context.executing_eagerly() or force_graph_mode_calibration:
      logging.info('Calibration step is executed in graph mode.')
      _run_graph_for_calibration_graph_mode(
          float_model_dir, tags, representative_dataset_map
      )
    else:
      logging.info('Calibration step is executed in eager mode.')
      _run_graph_for_calibration_eager_mode(
          float_model_dir, tags, representative_dataset_map
      )
  except Exception as ex:
    raise ValueError(
        'Failed to run graph for post-training quantization calibration.'
    ) from ex

  logging.info('Calibration step complete.')
def _run_calibration(
    saved_model_path: str,
    signature_keys: Sequence[str],
    tags: Collection[str],
    force_graph_mode_calibration: bool,
    representative_dataset_file_map: Mapping[
        str, quantization_options_pb2.RepresentativeDatasetFile
    ],
) -> bool:
  """Loads the representative datasets and runs the calibration step.

  Args:
    saved_model_path: Path to the SavedModel to run calibration.
    signature_keys: List of signature keys corresponding to SignatureDefs to run
      calibration on.
    tags: A set of tags that identify the MetaGraphDef.
    force_graph_mode_calibration: If True, runs the calibration in graph mode.
    representative_dataset_file_map: Signature key ->
      `RepresentativeDatasetFile` mapping for running the calibration step. Each
      dataset file stores the representative dataset for the function matching
      the signature key.

  Returns:
    `True` upon successfully running calibration.
  """
  dataset_loader = rd.TfRecordRepresentativeDatasetLoader(
      representative_dataset_file_map
  )
  repr_dataset_map = dataset_loader.load()

  # Uses the representative dataset to collect statistics for calibration.
  # After this operation, min & max values are stored separately in a global
  # CalibratorSingleton instance.
  _run_graph_for_calibration(
      float_model_dir=saved_model_path,
      signature_keys=signature_keys,
      tags=tags,
      representative_dataset=repr_dataset_map,
      force_graph_mode_calibration=force_graph_mode_calibration,
  )

  # Dummy value to indicate successful run, as `None` would indicate error. See
  # comments in `NotNoneT`.
  return True
def _call_and_return_none_on_error(
    func: Callable[[], NotNoneT], error_msg: str
) -> Optional[NotNoneT]:
  """Calls `func` and returns `None` on error.

  This is used to gracefully return the 'error status' represented as `None`, as
  raising exceptions from `PyFunctionLibrary` methods crashes the program.

  Args:
    func: The function to run. The function should be a callable returning a
      non-None value.
    error_msg: The error message to log upon error. Used for debugging purposes.

  Returns:
    `None` if the function raises an exception. The return value of `func`
    otherwise.
  """
  try:
    return func()
  except Exception as ex:  # pylint: disable=broad-exception-caught; Required for graceful failing with pybind11.
    # Prints the exception traceback for debuggability.
    # The explicit (type, value, traceback) form is used because the
    # single-argument `print_exception(ex)` call is only accepted from Python
    # 3.10 onwards; on older interpreters it raises `TypeError` from inside
    # this handler, which would defeat the purpose of this function.
    traceback.print_exception(type(ex), ex, ex.__traceback__)
    # Additional error log for debuggability.
    logging.error(error_msg)
    return None
def _save_model_and_copy_assets(
    exported_model: exported_model_pb2.ExportedModel,
    src_saved_model_path: str,
    dst_saved_model_path: str,
    signature_def_map: Mapping[str, meta_graph_pb2.SignatureDef],
    tags: Collection[str],
) -> bool:
  """Writes `exported_model` as a SavedModel and copies the source assets.

  Args:
    exported_model: ExportedModel to save.
    src_saved_model_path: Path to the source SavedModel. This will be used to
      copy the asset files to `dst_saved_model_path`.
    dst_saved_model_path: Destination path to save the exported model.
    signature_def_map: Signature key -> SignatureDef mapping.
    tags: Tags to attach to the saved MetaGraphDef.

  Returns:
    `True` upon successfully saving the model.
  """
  # Everything besides the graph, destination, signatures and tags is taken
  # from the fields of `exported_model`.
  extra_save_args = dict(
      init_op_name=exported_model.init_node_name,
      saver_def=_get_saver_def_or_none(exported_model),
      checkpoint_dir=exported_model.checkpoint_dir,
      function_aliases=exported_model.function_aliases,
      asset_file_defs=exported_model.asset_file_defs,
  )
  save_model.save_model_v1(
      exported_model.graph_def,
      dst_saved_model_path,
      signature_def_map,
      tags,
      **extra_save_args,
  )

  _copy_assets(src_saved_model_path, dst_saved_model_path)

  # Dummy value to indicate successful run, as `None` would indicate error. See
  # comments in `NotNoneT`.
  return True
class PyFunctionLibrary(pywrap_function_lib.PyFunctionLibrary):
  """Python implementations of the `PyFunctionLibrary` virtual methods.

  The methods of this class override C++ virtual functions declared in
  `pywrap_function_lib.PyFunctionLibrary`. Each method deserializes its
  proto arguments and then delegates the actual work to
  `_call_and_return_none_on_error`, so a failure in that work is reported by
  returning `None` instead of raising.
  """

  # LINT.IfChange(save_exported_model)
  def save_exported_model(
      self,
      dst_saved_model_path: str,
      exported_model_serialized: bytes,
      src_saved_model_path: str,
      tags: set[str],
      serialized_signature_def_map: dict[str, bytes],
  ) -> Optional[bool]:
    # LINT.ThenChange(py_function_lib.h:save_exported_model)
    """Saves `ExportedModel` to `dst_saved_model_path` as a SavedModel.

    Args:
      dst_saved_model_path: Destination path to save the exported model.
      exported_model_serialized: Exported model to export as SavedModel.
      src_saved_model_path: Path to the source SavedModel. This will be used to
        copy the asset files to `dst_saved_model_path`.
      tags: Tags to attach to the saved MetaGraphDef.
      serialized_signature_def_map: Signature key -> serialized SignatureDef.

    Returns:
      `True` upon successful execution. `None` when an error is raised
      internally.
    """
    exported_model = exported_model_pb2.ExportedModel.FromString(
        exported_model_serialized
    )

    # Deserialize values in signature_def_map.
    signature_def_map = {
        key: meta_graph_pb2.SignatureDef.FromString(serialized_signature_def)
        for key, serialized_signature_def in (
            serialized_signature_def_map.items()
        )
    }

    save_func = functools.partial(
        _save_model_and_copy_assets,
        exported_model,
        src_saved_model_path,
        dst_saved_model_path,
        signature_def_map,
        tags,
    )
    failure_msg = (
        f'Failed to save model "{dst_saved_model_path}",'
        f' signature_def_map: {signature_def_map}, tags: {tags}.'
    )
    return _call_and_return_none_on_error(
        func=save_func, error_msg=failure_msg
    )

  # TODO: b/311097139 - Extract calibration related functions into a separate
  # file.
  # LINT.IfChange(run_calibration)
  def run_calibration(
      self,
      saved_model_path: str,
      signature_keys: list[str],
      tags: set[str],
      force_graph_mode_calibration: bool,
      representative_dataset_file_map_serialized: dict[str, bytes],
  ) -> Optional[bool]:
    # LINT.ThenChange(py_function_lib.h:run_calibration)
    """Runs calibration and adds calibration statistics to exported model.

    Args:
      saved_model_path: Path to the SavedModel to run calibration.
      signature_keys: List of signature keys corresponding to SignatureDefs to
        run calibration on.
      tags: A set of tags that identify the MetaGraphDef.
      force_graph_mode_calibration: If True, runs the calibration in graph mode.
      representative_dataset_file_map_serialized: Signature key ->
        `RepresentativeDatasetFile` mapping for running the calibration step.
        Each dataset file stores the representative dataset for the function
        matching the signature key.

    Returns:
      `True` upon successful execution. `None` when an error is raised
      internally.
    """
    # Deserialize `RepresentativeDatasetFile` values.
    dataset_file_cls = quantization_options_pb2.RepresentativeDatasetFile
    dataset_file_map = {
        signature_key: dataset_file_cls.FromString(dataset_file_serialized)
        for signature_key, dataset_file_serialized in (
            representative_dataset_file_map_serialized.items()
        )
    }

    calibration_func = functools.partial(
        _run_calibration,
        saved_model_path,
        signature_keys,
        tags,
        force_graph_mode_calibration,
        dataset_file_map,
    )
    failure_msg = (
        f'Failed to run calibration on model "{saved_model_path}",'
        f' signature_keys: {signature_keys}, tags: {tags}.'
    )
    return _call_and_return_none_on_error(
        func=calibration_func, error_msg=failure_msg
    )

  # LINT.IfChange(get_calibration_min_max_value)
  def get_calibration_min_max_value(
      self,
      calibration_statistics_serialized: bytes,
      calibration_options_serialized: bytes,
  ) -> Optional[tuple[float, float]]:
    """Calculates min and max values from statistics.

    Args:
      calibration_statistics_serialized: Serialized `CalibrationStatistics`.
        This will be the source to calculate min and max values from.
      calibration_options_serialized: Serialized `CalibrationOptions`. Specifies
        how the min / max should be calculated.

    Returns:
      (min_value, max_value): Min and max calculated using calib_opts. `None`
      upon error.
    """
    # LINT.ThenChange(py_function_lib.h:get_calibration_min_max_value)

    # Deserialize values passed from c++.
    statistics = calibration_statistics_pb2.CalibrationStatistics.FromString(
        calibration_statistics_serialized
    )
    options = stablehlo_quant_config_pb2.CalibrationOptions.FromString(
        calibration_options_serialized
    )

    min_max_func = functools.partial(
        calibration_algorithm.get_min_max_value,
        statistics,
        options,
    )
    failure_msg = f'Retrieving calibrated min / max failed. Options: {options}.'
    return _call_and_return_none_on_error(min_max_func, error_msg=failure_msg)
@@ -0,0 +1,48 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Any, Optional
# Type stub: only signatures are declared here, every body is `...`.
# NOTE(review): presumably the stub for the pybind11-bound
# `pywrap_function_lib.PyFunctionLibrary` base class - confirm against the
# build target. The `LINT.IfChange` / `LINT.ThenChange` markers flag signatures
# that must be kept in sync with their counterparts elsewhere.
class PyFunctionLibrary:

  # Takes a serialized exported model plus serialized SignatureDefs (keyed by
  # signature key) and the source / destination SavedModel paths.
  # LINT.IfChange(save_exported_model)
  def save_exported_model(
      self,
      dst_saved_model_path: str,
      exported_model_serialized: bytes,
      src_saved_model_path: str,
      tags: set[str],
      serialized_signature_def_map: dict[str, bytes],
  ) -> Optional[bool]: ...
  # LINT.ThenChange()

  # Takes the SavedModel path, the signature keys / tags selecting what to
  # calibrate, and the serialized representative dataset files keyed by
  # signature key.
  # LINT.IfChange(run_calibration)
  def run_calibration(
      self,
      saved_model_path: str,
      signature_keys: list[str],
      tags: set[str],
      force_graph_mode_calibration: bool,
      # Value type: RepresentativeDatasetFile.
      representative_dataset_file_map_serialized: dict[str, bytes],
  ) -> Optional[bool]: ...
  # LINT.ThenChange()

  # Takes serialized calibration statistics and options; the declared result
  # is an optional (float, float) pair.
  # LINT.IfChange(get_calibration_min_max_value)
  def get_calibration_min_max_value(
      self,
      calibration_statistics_serialized: bytes,
      calibration_options_serialized: bytes,
  ) -> Optional[tuple[float, float]]: ...
  # LINT.ThenChange()
@@ -0,0 +1,72 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Any
from tensorflow.compiler.mlir.quantization.tensorflow.calibrator import calibration_statistics_pb2
from tensorflow.compiler.mlir.quantization.tensorflow.python import py_function_lib
from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as rd
# Stub signature for `quantize_qat_model`. Paths and the serialized
# `QuantizationOptions` are positional; everything after `*` is keyword-only.
# The result is typed `Any` because it is a Status object (see trailing
# comment).
# LINT.IfChange(quantize_qat_model)
def quantize_qat_model(
    src_saved_model_path: str,
    dst_saved_model_path: str,
    quantization_options_serialized: bytes,
    *,
    signature_keys: list[str],
    signature_def_map_serialized: dict[str, bytes],
    py_function_library: py_function_lib.PyFunctionLibrary,
) -> Any: ...  # Status
# LINT.ThenChange()
# Stub signature for `quantize_ptq_dynamic_range`. Paths and the serialized
# `QuantizationOptions` are positional; everything after `*` is keyword-only.
# The result is typed `Any` because it is a Status object (see trailing
# comment).
# LINT.IfChange(quantize_ptq_dynamic_range)
def quantize_ptq_dynamic_range(
    src_saved_model_path: str,
    dst_saved_model_path: str,
    quantization_options_serialized: bytes,
    *,
    signature_keys: list[str],
    signature_def_map_serialized: dict[str, bytes],
    py_function_library: py_function_lib.PyFunctionLibrary,
) -> Any: ...  # Status
# LINT.ThenChange()
# Stub signature for `quantize_weight_only`. Unlike the sibling stubs in this
# file it takes no `signature_keys` argument. Everything after `*` is
# keyword-only, and the result is typed `Any` because it is a Status object
# (see trailing comment).
# LINT.IfChange(quantize_weight_only)
def quantize_weight_only(
    src_saved_model_path: str,
    dst_saved_model_path: str,
    quantization_options_serialized: bytes,
    *,
    signature_def_map_serialized: dict[str, bytes],
    py_function_library: py_function_lib.PyFunctionLibrary,
) -> Any: ...  # Status
# LINT.ThenChange()
# Stub signature for `quantize_ptq_static_range`. In addition to the arguments
# shared with the other stubs it takes the serialized representative dataset
# files, keyed by signature key. Everything after `*` is keyword-only, and the
# result is typed `Any` because it is a Status object (see trailing comment).
# LINT.IfChange(quantize_ptq_static_range)
def quantize_ptq_static_range(
    src_saved_model_path: str,
    dst_saved_model_path: str,
    quantization_options_serialized: bytes,
    *,
    signature_keys: list[str],
    signature_def_map_serialized: dict[str, bytes],
    py_function_library: py_function_lib.PyFunctionLibrary,
    # Value type: RepresentativeDatasetFile.
    representative_dataset_file_map_serialized: dict[str, bytes],
) -> Any: ...  # Status
# LINT.ThenChange()
@@ -0,0 +1,926 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines TF Quantization API from SavedModel to SavedModel."""
import tempfile
from typing import Mapping, Optional
from absl import logging
from tensorflow.compiler.mlir.quantization.stablehlo import quantization_config_pb2 as stablehlo_quant_config_pb2
from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2 as quant_opts_pb2
from tensorflow.compiler.mlir.quantization.tensorflow.python import py_function_lib
from tensorflow.compiler.mlir.quantization.tensorflow.python import pywrap_quantize_model
from tensorflow.compiler.mlir.quantization.tensorflow.python import representative_dataset as repr_dataset
from tensorflow.compiler.mlir.quantization.tensorflow.python import save_model
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.lib.io import file_io
from tensorflow.python.saved_model import load as saved_model_load
from tensorflow.python.saved_model import loader_impl as saved_model_loader
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.trackable import autotrackable
from tensorflow.python.util import tf_export
# Type aliases for quant_opts_pb2 messages.
# Each of the four aliases below is the result of applying the `tf_export`
# decorator, with a `quantization.experimental.*` name, to the proto class.
# NOTE(review): presumably this also makes the class available under that
# public API name - confirm against `tf_export`.
_QuantizationOptions = tf_export.tf_export(
    'quantization.experimental.QuantizationOptions'
)(quant_opts_pb2.QuantizationOptions)

_QuantizationMethod = tf_export.tf_export(
    'quantization.experimental.QuantizationMethod'
)(quant_opts_pb2.QuantizationMethod)

_QuantizationComponentSpec = tf_export.tf_export(
    'quantization.experimental.QuantizationComponentSpec'
)(quant_opts_pb2.QuantizationComponentSpec)

_UnitWiseQuantizationSpec = tf_export.tf_export(
    'quantization.experimental.UnitWiseQuantizationSpec'
)(quant_opts_pb2.UnitWiseQuantizationSpec)

# Short names for types nested in the messages above (and for the calibration
# method type of the StableHLO quantization config).
_PresetMethod = _QuantizationMethod.PresetMethod
_CalibrationMethod = (
    stablehlo_quant_config_pb2.CalibrationOptions.CalibrationMethod
)

_QuantizationComponent = _QuantizationComponentSpec.QuantizationComponent
_TensorType = _QuantizationComponentSpec.TensorType

_RepresentativeDatasetFile = quant_opts_pb2.RepresentativeDatasetFile

# Mapping of signature def key -> SignatureDef.
_SignatureDefMap = Mapping[str, meta_graph_pb2.SignatureDef]

# Default minimum number of elements in the weights for them to be quantized
# during dynamic range quantization (DRQ) and weight-only quantization.
_DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS = 1024
def _is_qat_saved_model(saved_model_path: str):
  """Returns whether the SavedModel contains any 'FakeQuant*' op.

  A SavedModel is considered QAT-enabled when at least one node, either in the
  main graph of a MetaGraphDef or in one of its library functions, has an op
  name starting with 'FakeQuant'.

  Args:
    saved_model_path: Path to the SavedModel directory to inspect.

  Returns:
    `True` if a 'FakeQuant*' op is found, `False` otherwise.
  """

  def has_fake_quant(nodes) -> bool:
    return any(node.op.startswith('FakeQuant') for node in nodes)

  saved_model_proto = saved_model_loader.parse_saved_model(saved_model_path)
  for meta_graph in saved_model_proto.meta_graphs:
    graph_def = meta_graph.graph_def
    # Nodes of the main graph are checked before the function library.
    if has_fake_quant(graph_def.node):
      return True
    if any(
        has_fake_quant(function.node_def)
        for function in graph_def.library.function
    ):
      return True
  return False
def _serialize_signature_def_map(
    signature_def_map: _SignatureDefMap,
) -> dict[str, bytes]:
  """Builds a copy of `signature_def_map` with every value serialized.

  Args:
    signature_def_map: Signature key -> SignatureDef mapping.

  Returns:
    A new dict with the same keys, where each `SignatureDef` is replaced by the
    result of its `SerializeToString()`.
  """
  return {
      signature_key: signature_def.SerializeToString()
      for signature_key, signature_def in signature_def_map.items()
  }
def _save_representative_dataset(
    representative_dataset: repr_dataset.RepresentativeDatasetOrMapping,
    signature_def_map: _SignatureDefMap,
) -> Mapping[str, _RepresentativeDatasetFile]:
  """Saves the representative dataset to temporary TFRecord files.

  One temporary `.tfrecord` file is created per signature key of
  `signature_def_map`. The files are not removed by this function.

  Args:
    representative_dataset: Representative dataset used for the calibration
      step. Either a single dataset (only allowed when `signature_def_map` has
      one signature key) or a mapping from signature key to dataset, whose keys
      must match the keys of `signature_def_map`.
    signature_def_map: Signature def key -> SignatureDef mapping.

  Returns:
    A map from signature key to the saved representative dataset file.

  Raises:
    ValueError: When `representative_dataset` is a mapping whose keys differ
      from the keys of `signature_def_map`, or when it is not a mapping but
      `signature_def_map` has more than one signature key.
  """
  if isinstance(representative_dataset, Mapping):
    if set(signature_def_map.keys()) != set(representative_dataset.keys()):
      raise ValueError(
          'The signature keys and the keys of representative dataset map '
          f'do not match. Signature keys: {set(signature_def_map.keys())}, '
          f'representative dataset map: {set(representative_dataset.keys())}.'
      )
    representative_dataset_map = representative_dataset
  elif len(signature_def_map.keys()) > 1:
    raise ValueError(
        'Representative dataset is not a mapping (got: '
        f'{type(representative_dataset)}), but there is more than one '
        'signature key provided. Please provide a map of '
        '{signature_key -> dataset} with more than one signature key.'
    )
  else:
    representative_dataset_map = {
        list(signature_def_map.keys())[0]: representative_dataset,
    }

  # Save the representative dataset to temporary TFRecord files.
  path_map = {}
  expected_input_key_map = {}
  for signature_key, signature_def in signature_def_map.items():
    # Only the path of the temporary file is needed here. The `with` block
    # closes the handle right away so that no file descriptor is leaked, while
    # `delete=False` keeps the (empty) file on disk for the saver to write to.
    with tempfile.NamedTemporaryFile(
        suffix='.tfrecord', prefix=signature_key, delete=False
    ) as tmp_file:
      path_map[signature_key] = tmp_file.name
    expected_input_key_map[signature_key] = signature_def.inputs.keys()

  return repr_dataset.TfRecordRepresentativeDatasetSaver(
      path_map=path_map,
      expected_input_key_map=expected_input_key_map,
  ).save(representative_dataset_map)
def _run_static_range_qat(
    src_saved_model_path: str,
    dst_saved_model_path: str,
    quant_opts: _QuantizationOptions,
    signature_def_map: _SignatureDefMap,
) -> None:
  """Runs static-range quantization for a Quantization-Aware Trained model.

  Serializes the options and the SignatureDefs and hands them over to
  `pywrap_quantize_model.quantize_qat_model`.

  Args:
    src_saved_model_path: Path to the source SavedModel directory.
    dst_saved_model_path: Path to the destination SavedModel directory.
    quant_opts: Quantization options.
    signature_def_map: Signature def key -> SignatureDef mapping.
  """
  logging.info('Running static-range quantization for QAT model.')

  options_serialized = quant_opts.SerializeToString()
  signature_keys = list(quant_opts.signature_keys)
  signature_defs_serialized = _serialize_signature_def_map(signature_def_map)

  pywrap_quantize_model.quantize_qat_model(
      src_saved_model_path,
      dst_saved_model_path,
      quantization_options_serialized=options_serialized,
      signature_keys=signature_keys,
      signature_def_map_serialized=signature_defs_serialized,
      py_function_library=py_function_lib.PyFunctionLibrary(),
  )
def _run_static_range_ptq(
    src_saved_model_path: str,
    dst_saved_model_path: str,
    quant_opts: _QuantizationOptions,
    representative_dataset: Mapping[str, _RepresentativeDatasetFile],
    signature_def_map: _SignatureDefMap,
) -> None:
  """Runs static-range Post-Training Quantization.

  Runs static-range PTQ for the model. Runs the calibration step with
  `representative_dataset` to collect statistics required for quantization. This
  produces the quantized GraphDef along with the SignatureDefs which might have
  been modified according to the changes in the graph.

  Args:
    src_saved_model_path: Path to the source SavedModel directory.
    dst_saved_model_path: Path to the destination SavedModel directory.
    quant_opts: Quantization options.
    representative_dataset: A map from signature key to the saved representative
      dataset file.
    signature_def_map: Signature def key -> SignatureDef mapping.

  Raises:
    ValueError if the graph doesn't contain a valid signature.
  """
  logging.info('Running static-range post-training quantization.')

  signature_defs_serialized = _serialize_signature_def_map(signature_def_map)

  # `quantize_ptq_static_range` requires `RepresentativeDatasetFile`s to be
  # serialized. Serialize the values to match the type.
  dataset_files_serialized = {}
  for signature_key, dataset_file in representative_dataset.items():
    dataset_files_serialized[signature_key] = dataset_file.SerializeToString()

  options_serialized = quant_opts.SerializeToString()
  signature_keys = list(quant_opts.signature_keys)

  pywrap_quantize_model.quantize_ptq_static_range(
      src_saved_model_path,
      dst_saved_model_path,
      quantization_options_serialized=options_serialized,
      signature_keys=signature_keys,
      signature_def_map_serialized=signature_defs_serialized,
      py_function_library=py_function_lib.PyFunctionLibrary(),
      representative_dataset_file_map_serialized=dataset_files_serialized,
  )
def _static_range_quantize(
    src_saved_model_path: str,
    dst_saved_model_path: str,
    quantization_options: _QuantizationOptions,
    representative_dataset: Optional[
        repr_dataset.RepresentativeDatasetOrMapping
    ] = None,
) -> autotrackable.AutoTrackable:
  """Quantizes the given SavedModel via static range quantization.

  If the model is not trained with Quantization-Aware Training (QAT) technique,
  it requires `representative_dataset` to collect statistics required for
  quantization. If non-None `representative_dataset` is provided with a QAT
  model input, `representative_dataset` will be ignored.

  Args:
    src_saved_model_path: Path to the saved model. When representative_dataset
      is not provided, this should be a model trained with QAT.
    dst_saved_model_path: The path to save the output SavedModel. The directory
      will be overwritten if not empty.
    quantization_options: QuantizationOptions proto describing quantization
      related config.
    representative_dataset: a generator that returns a dictionary in {input_key:
      input_value} format or a tuple with signature key and a dictionary in
      {input_key: input_value} format that feeds calibration data for quantizing
      model. This should be provided when the model is not a QAT model. Must
      not be combined with the `representative_datasets` field of
      `quantization_options`.

  Returns:
    A SavedModel object with TF quantization applied.

  Raises:
    ValueError: when both the `representative_dataset` argument and the
      `representative_datasets` field of `quantization_options` are specified,
      or when no representative dataset is provided for a non-QAT model.
    RuntimeError: When a MetaGraphDef could not be found associated with `tags`
      in the SavedModel.
  """
  logging.info(
      'Running static range quantization on model: %s', src_saved_model_path
  )
  logging.info('QuantizationOptions: \n%s', quantization_options)

  # Both cases take the same code path below, which needs no calibration data.
  is_qat_saved_model_or_method_no_quantize = _is_qat_saved_model(
      src_saved_model_path
  ) or (
      quantization_options.quantization_method.preset_method
      == _QuantizationMethod.METHOD_NO_QUANTIZE
  )
  signature_def_map = save_model.get_signatures_from_saved_model(
      src_saved_model_path,
      quantization_options.signature_keys,
      set(quantization_options.tags),
  )

  if (
      representative_dataset is not None
      and quantization_options.representative_datasets
  ):
    raise ValueError(
        'Do not specify both the `representative_dataset` argument and'
        ' the `representative_datasets` field in `QuantizationOptions`.'
    )

  # Datasets given through the options are already stored in files; a dataset
  # given as an argument is saved to temporary files first.
  saved_representative_dataset = quantization_options.representative_datasets
  if representative_dataset is not None:
    saved_representative_dataset = _save_representative_dataset(
        representative_dataset, signature_def_map
    )

  # Checks if the model is from QAT or method is METHOD_NO_QUANTIZE.
  if (
      not saved_representative_dataset
      and not is_qat_saved_model_or_method_no_quantize
  ):
    raise ValueError(
        'When `representative_dataset` is not provided, the model should be '
        'trained with quantization-aware training (QAT).'
    )

  if quantization_options.min_num_elements_for_weights > 0:
    # `logging.warning` is used instead of the deprecated `logging.warn` alias.
    logging.warning(
        'min_num_elements_for_weights is set but is not supported for the '
        'Post-training static range quantization. '
        'The flag is ignored.'
    )

  if is_qat_saved_model_or_method_no_quantize:
    _run_static_range_qat(
        src_saved_model_path,
        dst_saved_model_path,
        quantization_options,
        signature_def_map,
    )
  else:
    _run_static_range_ptq(
        src_saved_model_path,
        dst_saved_model_path,
        quantization_options,
        saved_representative_dataset,
        signature_def_map,
    )

  return saved_model_load.load(dst_saved_model_path)
def _dynamic_range_quantize(
    src_saved_model_path: str,
    dst_saved_model_path: str,
    quantization_options: _QuantizationOptions,
) -> autotrackable.AutoTrackable:
  """Applies post-training dynamic range quantization to a SavedModel.

  Args:
    src_saved_model_path: Location of the SavedModel to quantize.
    dst_saved_model_path: Location where the quantized SavedModel is written.
      The directory will be overwritten if not empty.
    quantization_options: QuantizationOptions proto carrying the quantization
      configuration (signature keys, tags, ...).

  Returns:
    The quantized SavedModel, loaded back from `dst_saved_model_path`.

  Raises:
    ValueError: If the source model is a QAT model.
  """
  mode_str = 'dynamic-range quantization'

  # Dynamic range quantization is a post-training technique; reject QAT models.
  if _is_qat_saved_model(src_saved_model_path):
    qat_error_message = (
        'The models trained with quantization-aware training (QAT) is not '
        'supported for %s.' % mode_str
    )
    raise ValueError(qat_error_message)

  logging.info(
      'Running post-training %s on model: %s', mode_str, src_saved_model_path
  )
  logging.info('QuantizationOptions: \n%s', quantization_options)

  signatures = save_model.get_signatures_from_saved_model(
      src_saved_model_path,
      quantization_options.signature_keys,
      quantization_options.tags,
  )

  # Prepare the arguments of the native quantization entry point.
  serialized_options = quantization_options.SerializeToString()
  requested_signature_keys = list(quantization_options.signature_keys)
  serialized_signatures = _serialize_signature_def_map(signatures)
  function_library = py_function_lib.PyFunctionLibrary()

  # Apply post-training dynamic range quantization to the model.
  pywrap_quantize_model.quantize_ptq_dynamic_range(
      src_saved_model_path,
      dst_saved_model_path,
      quantization_options_serialized=serialized_options,
      signature_keys=requested_signature_keys,
      signature_def_map_serialized=serialized_signatures,
      py_function_library=function_library,
  )

  return saved_model_load.load(dst_saved_model_path)
def _weight_only_quantize(
    src_saved_model_path: str,
    dst_saved_model_path: str,
    quantization_options: quant_opts_pb2.QuantizationOptions,
) -> autotrackable.AutoTrackable:
  """Applies weight-only quantization to a SavedModel.

  Args:
    src_saved_model_path: Location of the SavedModel to quantize.
    dst_saved_model_path: Location where the quantized SavedModel is written.
      The directory will be overwritten if not empty.
    quantization_options: QuantizationOptions proto carrying the quantization
      configuration (signature keys, tags, ...).

  Returns:
    The quantized SavedModel, loaded back from `dst_saved_model_path`.

  Raises:
    ValueError: If the source model is a QAT model.
  """
  mode_str = 'weight-only quantization'

  # QAT weight-only is not supported yet.
  if _is_qat_saved_model(src_saved_model_path):
    qat_error_message = (
        'The models trained with quantization-aware training (QAT) is not '
        'supported for %s.' % mode_str
    )
    raise ValueError(qat_error_message)

  logging.info(
      'Running post-training %s on model: %s', mode_str, src_saved_model_path
  )
  logging.info('QuantizationOptions: \n%s', quantization_options)

  requested_signature_keys = list(quantization_options.signature_keys)
  requested_tags = set(quantization_options.tags)
  signatures = save_model.get_signatures_from_saved_model(
      src_saved_model_path,
      requested_signature_keys,
      requested_tags,
  )

  # Prepare the arguments of the native quantization entry point.
  serialized_options = quantization_options.SerializeToString()
  serialized_signatures = _serialize_signature_def_map(signatures)
  function_library = py_function_lib.PyFunctionLibrary()

  pywrap_quantize_model.quantize_weight_only(
      src_saved_model_path,
      dst_saved_model_path,
      quantization_options_serialized=serialized_options,
      signature_def_map_serialized=serialized_signatures,
      py_function_library=function_library,
  )

  return saved_model_load.load(dst_saved_model_path)
def _verify_output_dir(output_dir: Optional[str], overwrite: bool) -> None:
  """Checks that `output_dir` may be used to write the output saved model.

  This function only inspects `output_dir`; it never creates or modifies it.

  Args:
    output_dir: Output directory, or None when no directory was given (in which
      case there is nothing to verify).
    overwrite: Whether overwriting a non-empty existing output directory is
      allowed.

  Raises:
    FileExistsError: Iff `output_dir` is not empty and `overwrite` is false.
  """
  if output_dir is None:
    return

  if not file_io.file_exists_v2(output_dir):
    return

  # The directory is listed before `overwrite` is consulted, so listing
  # failures surface regardless of the `overwrite` setting.
  existing_entries = file_io.list_directory_v2(output_dir)
  if existing_entries and not overwrite:
    raise FileExistsError(
        f'Output directory already exists: {output_dir} . '
        'Please set overwrite_output_directory to true to '
        'overwrite the existing directory.'
    )
def _populate_quantization_component_spec(
    quant_method: _QuantizationMethod,
) -> None:
  """Populates default values for QuantizationComponentSpec.

  Starts from the default component specs implied by
  `quant_method.preset_method`, overlays any specs already present in
  `quant_method.quantization_component_specs` (after validating their tensor
  types), and writes the merged specs back into `quant_method` in place.

  Args:
    quant_method: The quantization method to be updated.

  Raises:
    ValueError: If a provided spec uses an unsupported tensor type (only int8
      for weight/activation components and int32 for any other component), or
      if the number of merged specs does not match the preset method (3 for
      static/dynamic range int8, 1 for static range weight-only int8).
  """
  preset_method = quant_method.preset_method
  uses_all_components = (
      preset_method == _PresetMethod.METHOD_STATIC_RANGE_INT8
      or preset_method == _PresetMethod.METHOD_DYNAMIC_RANGE_INT8
  )
  is_weight_only = (
      preset_method == _PresetMethod.METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8
  )

  # Keyed by component so that there is exactly one spec per component.
  updated_component_spec = dict()

  # Populate default configuration.
  if uses_all_components:
    updated_component_spec[_QuantizationComponent.COMPONENT_ACTIVATION] = (
        _QuantizationComponentSpec(
            quantization_component=_QuantizationComponent.COMPONENT_ACTIVATION,
            tensor_type=_TensorType.TENSORTYPE_INT_8,
        )
    )
    updated_component_spec[_QuantizationComponent.COMPONENT_WEIGHT] = (
        _QuantizationComponentSpec(
            quantization_component=_QuantizationComponent.COMPONENT_WEIGHT,
            tensor_type=_TensorType.TENSORTYPE_INT_8,
        )
    )
    updated_component_spec[_QuantizationComponent.COMPONENT_BIAS] = (
        _QuantizationComponentSpec(
            quantization_component=_QuantizationComponent.COMPONENT_BIAS,
            tensor_type=_TensorType.TENSORTYPE_INT_32,
        )
    )
  elif is_weight_only:
    updated_component_spec[_QuantizationComponent.COMPONENT_WEIGHT] = (
        _QuantizationComponentSpec(
            quantization_component=_QuantizationComponent.COMPONENT_WEIGHT,
            tensor_type=_TensorType.TENSORTYPE_INT_8,
        )
    )

  # Override if quantization_component_spec is specified.
  if quant_method.quantization_component_specs:
    # Check if the component spec is supported configuration in TF-Quant.
    for component_spec in quant_method.quantization_component_specs:
      if component_spec.quantization_component in [
          _QuantizationComponent.COMPONENT_WEIGHT,
          _QuantizationComponent.COMPONENT_ACTIVATION,
      ]:
        if component_spec.tensor_type != _TensorType.TENSORTYPE_INT_8:
          raise ValueError(
              'Only int8 precision is supported for input operands.'
          )
      else:
        # Every component other than weight/activation is validated as bias.
        if component_spec.tensor_type != _TensorType.TENSORTYPE_INT_32:
          raise ValueError('Only int32 precision is supported for bias.')
      # Update with the custom spec.
      updated_component_spec[component_spec.quantization_component] = (
          component_spec
      )

  # Update the component specs: replace the repeated field's contents with the
  # merged (defaults + overrides) specs.
  del quant_method.quantization_component_specs[:]
  quant_method.quantization_component_specs.extend(
      updated_component_spec.values()
  )

  num_specs = len(quant_method.quantization_component_specs)
  if uses_all_components and num_specs != 3:
    # Format a single message; passing `quant_method` as a second positional
    # argument would make the exception text render as a tuple.
    raise ValueError(f'Only 3 components are needed for {quant_method}')
  elif is_weight_only and num_specs != 1:
    raise ValueError('At least one component spec needs to be specified.')
def _populate_unitwise_quantization_specs(
    quantization_options: _QuantizationOptions,
) -> None:
  """Verifies and populates unit-wise quantization specs.

  Each unit-wise spec must name at least one unit, and every unit must carry an
  `op_type` or a `node_name`. The quantization method of each unit-wise spec is
  populated with default component specs; if it ends up with any component
  specs, they must equal the top-level ones.

  Args:
    quantization_options: QuantizationOptions whose
      `unit_wise_quantization_specs` are verified and populated in place.

  Raises:
    ValueError: If a unit-wise spec has no unit, a unit has neither `op_type`
      nor `node_name`, or the component specs of a unit-wise spec differ from
      the top-level component specs.
  """
  unitwise_specs = quantization_options.unit_wise_quantization_specs
  if not unitwise_specs:
    return

  def _component_of(spec):
    return spec.quantization_component

  # Sort so that the comparison below does not depend on the order of specs.
  top_level_specs = sorted(
      quantization_options.quantization_method.quantization_component_specs,
      key=_component_of,
  )

  for unitwise_spec in unitwise_specs:
    units = unitwise_spec.unit
    if not units:
      raise ValueError(
          'UnitWiseQuantizationSpec must contain at least one unit.'
      )
    if any(not unit.op_type and not unit.node_name for unit in units):
      raise ValueError('Either `op_type` or `node_name` must be specified.')

    unit_method = unitwise_spec.quantization_method
    _populate_quantization_component_spec(unit_method)

    unit_component_specs = unit_method.quantization_component_specs
    if not unit_component_specs:
      continue
    if top_level_specs != sorted(unit_component_specs, key=_component_of):
      raise ValueError(
          'Currently unit-wise quantization spec only supports NO_QUANTIZE and'
          ' same quantization method as the top-level `quantization_method`'
      )
def _populate_calibration_options(
    quantization_options: quant_opts_pb2.QuantizationOptions,
) -> None:
  """Populates default values for CalibrationOptions.

  * An unspecified calibration method defaults to MIN_MAX.
  * HISTOGRAM_PERCENTILE defaults: `num_bins=512`, `min_percentile=0.001`,
    `max_percentile=99.999`.
  * HISTOGRAM_MSE_* methods require an int8 activation spec and default
    `num_bins` to 512.
  * If `calibration_data_dir` is set, an empty output directory is created
    there via `save_model.create_empty_output_dir`.

  Args:
    quantization_options: An instance of QuantizationOptions with a field
      specifying CalibrationOptions. Updated in place.

  Raises:
    ValueError: If a HISTOGRAM_MSE calibration method is selected and the
      activation component spec is missing or is not TENSORTYPE_INT_8.
  """
  calib_opts = quantization_options.calibration_options
  if (
      calib_opts.calibration_method
      == _CalibrationMethod.CALIBRATION_METHOD_UNSPECIFIED
  ):
    calib_opts.calibration_method = (
        _CalibrationMethod.CALIBRATION_METHOD_MIN_MAX
    )
  elif (
      calib_opts.calibration_method
      == _CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_PERCENTILE
  ):
    if not calib_opts.calibration_parameters.num_bins:
      calib_opts.calibration_parameters.num_bins = 512
    if not calib_opts.calibration_parameters.min_percentile:
      calib_opts.calibration_parameters.min_percentile = 0.001
    if not calib_opts.calibration_parameters.max_percentile:
      calib_opts.calibration_parameters.max_percentile = 99.999
  # Check the activation_tensor_type of HISTOGRAM_MSE methods.
  elif calib_opts.calibration_method in [
      _CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE,
      _CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY,
      _CalibrationMethod.CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC,
  ]:
    component_specs = (
        quantization_options.quantization_method.quantization_component_specs
    )
    # Find the activation spec by its component. The repeated field is a plain
    # list, so using the enum constant as a list index would pick whichever
    # spec happens to sit at that position (or raise IndexError).
    activation_tensor_type = next(
        (
            spec.tensor_type
            for spec in component_specs
            if spec.quantization_component
            == _QuantizationComponent.COMPONENT_ACTIVATION
        ),
        None,
    )
    # Unlike the HISTOGRAM_PERCENTILE method, the HISTOGRAM_MSE method uses
    # num_bits because it actually quantizes and dequantizes values.
    # A missing activation spec (None) is rejected by the same check.
    if activation_tensor_type != _TensorType.TENSORTYPE_INT_8:
      raise ValueError(
          'Only TENSORTYPE_INT_8 is supported for HISTOGRAM_MSE calibration'
          f' methods. calibration_method={calib_opts.calibration_method}'
      )

    if not calib_opts.calibration_parameters.num_bins:
      calib_opts.calibration_parameters.num_bins = 512

  if calib_opts.calibration_data_dir:
    save_model.create_empty_output_dir(
        calib_opts.calibration_data_dir,
        overwrite=calib_opts.force_regenerate_calibration_data,
    )
def _populate_quantization_options_default_values(
    quantization_options: _QuantizationOptions,
) -> None:
  """Populates default values for QuantizationOptions.

  Populates unspecified or unset fields of QuantizationOptions with the default
  values, in place, and validates the resulting configuration.

  * If `op_set` is unspecified, it defaults to `OpSet.XLA`.
  * If `tags` is empty, it defaults to the serving tag.
  * If `signature_keys` is empty, it defaults to the default serving signature
    def key.
  * If `freeze_all_variables` is not set, it defaults to `True`.
  * If the preset method is unspecified, it defaults to static range int8.
  * If `min_num_elements_for_weights` is 0, it is set to the module default.
  * If `enable_per_channel_quantization` is not set, it defaults to `False`.
  * `force_graph_mode_calibration` is always set to `True`.
  * If `debugger_config` is set and its `log_dir_path` is empty, the path
    defaults to `/tmp/dumps`.
  * Check if configurations are set correctly:
    - Per-channel quantization is supported for the Uniform Quantized opset,
      for weight-only quantization, or for the XLA/StableHLO opsets with static
      range quantization.
    - The TF and Uniform Quantized opsets do not support weight-only
      quantization.
    - The StableHLO opset supports only static range and weight-only
      quantization.

  Args:
    quantization_options: An instance of QuantizationOptions.

  Raises:
    ValueError: If legacy weight-only is enabled, if one of the configuration
      checks above fails, if `debugger_config` is set without a debugger type,
      or if the whole-model debugger is used without
      `unquantized_dump_model_path`. Errors raised while populating component
      specs, unit-wise specs and calibration options also propagate.
  """
  if quantization_options.op_set == quant_opts_pb2.OpSet.OP_SET_UNSPECIFIED:
    quantization_options.op_set = quant_opts_pb2.OpSet.XLA

  if not quantization_options.tags:
    quantization_options.tags.append(tag_constants.SERVING)

  if not quantization_options.signature_keys:
    quantization_options.signature_keys.append(
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
    )

  if not quantization_options.HasField('freeze_all_variables'):
    quantization_options.freeze_all_variables = True

  if quantization_options.enable_legacy_weight_only:
    raise ValueError(
        'Legacy weight-only is deprecated. Use weight-only quantization method.'
    )

  # Converter assumes options are specified. So set SRQ explicitly.
  if (
      quantization_options.quantization_method.preset_method
      == _PresetMethod.METHOD_UNSPECIFIED
  ):
    logging.debug(
        '"preset_method" for QuantizationMethod is not specified.'
        ' Static range quantization is used by default.'
    )
    quantization_options.quantization_method.preset_method = (
        _PresetMethod.METHOD_STATIC_RANGE_INT8
    )

  # Check default quantization option values for weight-only quantization.
  # TODO(b/242805842): Find good minimum_elements_for_weights number for server.
  # please also update default value in tflite converter:
  # tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc;l=201
  if quantization_options.min_num_elements_for_weights == 0:
    quantization_options.min_num_elements_for_weights = (
        _DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS
    )
    logging.warning(
        (
            'QuantizationOptions.min_num_elements_for_weights is not set (0).'
            ' Setting to the default value: %d.'
        ),
        _DYNAMIC_RANGE_DEFAULT_MIN_NUM_ELEMENTS_FOR_WEIGHTS,
    )

  if not quantization_options.HasField('enable_per_channel_quantization'):
    quantization_options.enable_per_channel_quantization = False

  if quantization_options.enable_per_channel_quantization and not (
      (
          quantization_options.op_set == quant_opts_pb2.OpSet.UNIFORM_QUANTIZED
          or quantization_options.quantization_method.preset_method
          == _PresetMethod.METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8
      )
      or (
          quantization_options.op_set
          in (quant_opts_pb2.OpSet.XLA, quant_opts_pb2.OpSet.STABLEHLO)
          and quantization_options.quantization_method.preset_method
          == _PresetMethod.METHOD_STATIC_RANGE_INT8
      )
  ):
    raise ValueError(
        'Currently, per-channel quantization is supported for Uniform Quantized'
        ' opset, weight only quantization, or XLA/StableHLO opset with static'
        ' range quantization.'
    )

  if (
      quantization_options.quantization_method.preset_method
      == _PresetMethod.METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8
      and (
          quantization_options.op_set == quant_opts_pb2.OpSet.UNIFORM_QUANTIZED
          or quantization_options.op_set == quant_opts_pb2.OpSet.TF
      )
  ):
    raise ValueError('TF/Uniform quantized opset does not support weight-only.')

  if (quantization_options.op_set == quant_opts_pb2.OpSet.STABLEHLO) and (
      quantization_options.quantization_method.preset_method
      != _PresetMethod.METHOD_STATIC_RANGE_INT8
      and quantization_options.quantization_method.preset_method
      != _PresetMethod.METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8
  ):
    raise ValueError(
        'StableHLO quantized opset currently only supports static range'
        ' quantization and weight-only quantization via TF Quantizer.'
    )

  # Set `force_graph_mode_calibration` to True to avoid skipping op execution,
  # which are not connected to return ops, during calibration execution.
  # TODO: b/335031954 - Bring back support to run calibration in Eager mode.
  logging.debug(
      'Setting `force_graph_mode_calibration = True` to ensure the calibration'
      ' mode is executed properly.'
  )
  quantization_options.force_graph_mode_calibration = True

  if quantization_options.HasField('debugger_config'):
    if not quantization_options.debugger_config.log_dir_path:
      quantization_options.debugger_config.log_dir_path = '/tmp/dumps'

    if (
        quantization_options.debugger_config.debugger_type
        == stablehlo_quant_config_pb2.DebuggerConfig.DebuggerType.DEBUGGER_TYPE_UNSPECIFIED
    ):
      raise ValueError(
          'Debugger is enabled but debugger type was not specified.'
      )

    if (
        quantization_options.debugger_config.debugger_type
        == stablehlo_quant_config_pb2.DebuggerConfig.DebuggerType.DEBUGGER_TYPE_WHOLE_MODEL
        and not quantization_options.debugger_config.unquantized_dump_model_path
    ):
      raise ValueError(
          'Debugger type whole model verify was used but'
          ' unquantized_dump_model_path was not specified.'
      )

  # Check and populate quantization component spec.
  _populate_quantization_component_spec(
      quantization_options.quantization_method
  )
  # Verify and populate unit-wise quantization specs.
  _populate_unitwise_quantization_specs(quantization_options)

  if (
      quantization_options.quantization_method.preset_method
      == _PresetMethod.METHOD_STATIC_RANGE_INT8
  ):
    # Check and populate calibration options.
    _populate_calibration_options(quantization_options)
@tf_export.tf_export('quantization.experimental.quantize_saved_model')
def quantize(
    saved_model_path: str,
    output_directory: Optional[str] = None,
    quantization_options: Optional[_QuantizationOptions] = None,
    representative_dataset: Optional[
        repr_dataset.RepresentativeDatasetOrMapping
    ] = None,
    *,
    overwrite_output_directory: bool = False,
) -> autotrackable.AutoTrackable:
  """Quantizes the SavedModel with the given quantization options.

  Example usage:
  ```python
  # Quantizing a model trained with QAT.
  quantization_options = tf.quantization.experimental.QuantizationOptions(
      signature_keys=['your_signature_key'],
  )
  tf.quantization.experimental.quantize_saved_model(
      '/tmp/input_model',
      '/tmp/output_model',
      quantization_options=quantization_options,
  )

  # When quantizing a model trained without QAT (Post-Training Quantization),
  # a representative dataset is required.
  representative_dataset = [{"input": tf.random.uniform(shape=(3, 3))}
                        for _ in range(256)]
  tf.quantization.experimental.quantize_saved_model(
      '/tmp/input_model',
      '/tmp/output_model',
      quantization_options=quantization_options,
      representative_dataset={'your_signature_key': representative_dataset},
    )

  # In addition to preset quantization methods, fine-grained control of
  # quantization for each component is also supported.
  _QuantizationComponentSpec = (
      tf.quantization.experimental.QuantizationComponentSpec
  )
  quantization_options = tf.quantization.experimental.QuantizationOptions(
      signature_keys=['your_signature_key'],
      quantization_method=tf.quantization.experimental.QuantizationMethod(
          quantization_component_specs=[
              _QuantizationComponentSpec(
                  quantization_component=(
                      _QuantizationComponentSpec.COMPONENT_ACTIVATION
                  ),
                  tensor_type=_QuantizationComponentSpec.TENSORTYPE_INT_8,
              )
          ]
      )
  )
  tf.quantization.experimental.quantize_saved_model(
      '/tmp/input_model',
      '/tmp/output_model',
      quantization_options=quantization_options,
  )
  ```

  Args:
    saved_model_path: Path to the saved model. When representative_dataset is
      not provided, this should be a model trained with QAT.
    output_directory: The path to save the output SavedModel. Set
      `overwrite_output_directory` to `True` to overwrite any existing contents
      in the directory if not empty. If None, a fresh temporary directory is
      created and used.
    quantization_options: A set of options for quantization. If None, it uses
      post-training static range quantization with XLA opset by default. The
      passed instance is updated in place with default values.
    representative_dataset: an iterator that returns a dictionary of {input_key:
      input_value} or a map from signature key to a dictionary of {input_key:
      input_value} that feeds calibration data for quantizing model. The
      representative should be provided when the model is a PTQ model. It can be
      provided either via this parameter or via the `representative_datasets`
      field in `QuantizationOptions`.
    overwrite_output_directory: If set to true, overwrites the output directory
      iff it isn't empty. The default value is false.

  Returns:
    A SavedModel object with TF quantization applied.

  Raises:
    ValueError: When 1) representative_dataset is not provided for non QAT model
      for enabling static range quantization, 2) invalid value is provided as
      a quantization method, or 3) provide representative dataset via both
      argument and QuantizationOptions.
    ValueError: When the specified quantization method is not yet supported.
    FileExistsError: When `output_directory` is not empty and
      `overwrite_output_directory` is false.
  """
  _verify_output_dir(output_directory, overwrite_output_directory)

  # Set default values for None arguments.
  if output_directory is None:
    output_directory = tempfile.mkdtemp()

  if quantization_options is None:
    quantization_options = _QuantizationOptions()

  _populate_quantization_options_default_values(quantization_options)

  method: _QuantizationMethod = quantization_options.quantization_method
  if (
      method.preset_method == _PresetMethod.METHOD_STATIC_RANGE_INT8
      or method.preset_method == _PresetMethod.METHOD_NO_QUANTIZE
  ):
    return _static_range_quantize(
        saved_model_path,
        output_directory,
        quantization_options,
        representative_dataset,
    )
  elif method.preset_method == _PresetMethod.METHOD_DYNAMIC_RANGE_INT8:
    return _dynamic_range_quantize(
        saved_model_path,
        output_directory,
        quantization_options,
    )
  elif (
      method.preset_method == _PresetMethod.METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8
  ):
    return _weight_only_quantize(
        saved_model_path,
        output_directory,
        quantization_options,
    )
  else:
    # The f-prefix is required so that the offending preset method is
    # interpolated into the message instead of the literal placeholder text.
    raise ValueError(
        f'Quantization method {method.preset_method} is not supported.'
    )
@@ -0,0 +1,402 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines types required for representative datasets for quantization."""
from collections.abc import Collection, Sized
import os
from typing import Iterable, Mapping, Optional, Union
import numpy as np
from tensorflow.compiler.mlir.quantization.tensorflow import quantization_options_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.client import session
from tensorflow.python.data.ops import readers
from tensorflow.python.eager import context
from tensorflow.python.framework import tensor_util
from tensorflow.python.lib.io import python_io
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.types import core
from tensorflow.python.util import tf_export
# A representative sample is a map of: input_key -> input_value.
# Ex.: {'dense_input': tf.constant([1, 2, 3])}
# Ex.: {'x1': np.array([4, 5, 6])}
RepresentativeSample = Mapping[str, core.TensorLike]

# A representative dataset is an iterable of representative samples. It is only
# required to be iterable, so it may be a one-shot generator without a length.
RepresentativeDataset = Iterable[RepresentativeSample]

# A type representing a map from: signature key -> representative dataset.
# Ex.: {'serving_default': [{'input': tf.constant([1, 2, 3])},
#                           {'input': tf.constant([4, 5, 6])}],
#       'other_signature_key': [{'input': tf.constant([[2, 2], [9, 9]])}]}
RepresentativeDatasetMapping = Mapping[str, RepresentativeDataset]

# A type alias expressing that it can be either a RepresentativeDataset or
# a mapping of signature key to RepresentativeDataset.
RepresentativeDatasetOrMapping = Union[
    RepresentativeDataset, RepresentativeDatasetMapping
]

# Type aliases for quantization_options_pb2 messages.
_RepresentativeDataSample = quantization_options_pb2.RepresentativeDataSample
_RepresentativeDatasetFile = quantization_options_pb2.RepresentativeDatasetFile
class RepresentativeDatasetSaver:
  """Base class for representative dataset savers.

  Subclasses implement the single method `save`, which writes the provided
  representative dataset into files.

  Saving is useful for keeping a snapshot of a representative dataset on a file
  system, or for passing the representative dataset around as files.
  """

  def save(
      self, representative_dataset: RepresentativeDatasetMapping
  ) -> Mapping[str, _RepresentativeDatasetFile]:
    """Saves the representative dataset.

    Args:
      representative_dataset: RepresentativeDatasetMapping, i.e. a
        signature_def_key -> representative dataset mapping.

    Raises:
      NotImplementedError: Always; this base implementation must be overridden.
    """
    not_implemented_message = 'Method "save" is not implemented.'
    raise NotImplementedError(not_implemented_message)
@tf_export.tf_export(
    'quantization.experimental.TfRecordRepresentativeDatasetSaver'
)
class TfRecordRepresentativeDatasetSaver(RepresentativeDatasetSaver):
  """Representative dataset saver in TFRecord format.

  Saves representative datasets for quantization calibration in TFRecord format.
  The samples are serialized as `RepresentativeDataSample`.

  The `save` method returns a signature key to `RepresentativeDatasetFile` map,
  which can be used for QuantizationOptions.

  Example usage:

  ```python
  # Creating the representative dataset.
  representative_dataset = [{"input": tf.random.uniform(shape=(3, 3))}
                        for _ in range(256)]

  # Saving to a TFRecord file.
  dataset_file_map = (
    tf.quantization.experimental.TfRecordRepresentativeDatasetSaver(
          path_map={'serving_default': '/tmp/representative_dataset_path'}
      ).save({'serving_default': representative_dataset})
  )

  # Using in QuantizationOptions.
  quantization_options = tf.quantization.experimental.QuantizationOptions(
      signature_keys=['serving_default'],
      representative_datasets=dataset_file_map,
  )
  tf.quantization.experimental.quantize_saved_model(
      '/tmp/input_model',
      '/tmp/output_model',
      quantization_options=quantization_options,
  )
  ```
  """

  def __init__(
      self,
      path_map: Mapping[str, os.PathLike[str]],
      expected_input_key_map: Optional[Mapping[str, Collection[str]]] = None,
  ):
    """Initializes TFRecord representative dataset saver.

    Args:
      path_map: Signature def key -> path mapping. Each path is a TFRecord file
        to which a `RepresentativeDataset` is saved. The signature def keys
        should be a subset of the `SignatureDef` keys of the
        `representative_dataset` argument of the `save()` call.
      expected_input_key_map: Signature def key -> expected input keys. If set,
        validate that the sample has same set of input keys before saving.

    Raises:
      KeyError: If path_map and expected_input_key_map have different keys.
    """
    self.path_map: Mapping[str, os.PathLike[str]] = path_map
    self.expected_input_key_map: Mapping[str, Collection[str]] = {}
    if expected_input_key_map is not None:
      if set(path_map.keys()) != set(expected_input_key_map.keys()):
        raise KeyError(
            'The `path_map` and `expected_input_key_map` should have the same'
            ' set of keys.'
        )

      self.expected_input_key_map = expected_input_key_map

  def _save_tf_record_dataset(
      self,
      repr_ds: RepresentativeDataset,
      signature_def_key: str,
  ) -> _RepresentativeDatasetFile:
    """Saves `repr_ds` to a TFRecord file.

    Each sample in `repr_ds` is serialized as `RepresentativeDataSample`.

    Args:
      repr_ds: `RepresentativeDataset` to save.
      signature_def_key: The signature def key associated with `repr_ds`.

    Returns:
      a RepresentativeDatasetFile instance contains the path to the saved file.

    Raises:
      KeyError: If the set of input keys in the dataset samples doesn't match
      the set of expected input keys.
    """
    # When running in graph mode (TF1), tf.Tensor types should be converted to
    # numpy ndarray types to be compatible with `make_tensor_proto`.
    if not context.executing_eagerly():
      with session.Session() as sess:
        repr_ds = replace_tensors_by_numpy_ndarrays(repr_ds, sess)

    expected_input_keys = self.expected_input_key_map.get(signature_def_key)
    if expected_input_keys is not None:
      # Normalize to a set once. The expected keys may be any `Collection`
      # (e.g. a list or tuple), and a set never compares equal to those, which
      # would reject every sample.
      expected_input_keys = set(expected_input_keys)

    tfrecord_file_path = self.path_map[signature_def_key]

    with python_io.TFRecordWriter(tfrecord_file_path) as writer:
      for repr_sample in repr_ds:
        if (
            expected_input_keys is not None
            and set(repr_sample.keys()) != expected_input_keys
        ):
          raise KeyError(
              'Invalid input keys for representative sample. The function'
              f' expects input keys of: {expected_input_keys}. Got:'
              f' {set(repr_sample.keys())}. Please provide correct input keys'
              ' for representative samples.'
          )

        sample = _RepresentativeDataSample()
        for input_name, input_value in repr_sample.items():
          sample.tensor_proto_inputs[input_name].CopyFrom(
              tensor_util.make_tensor_proto(input_value)
          )

        writer.write(sample.SerializeToString())

    logging.info(
        'Saved representative dataset for signature def: %s to: %s',
        signature_def_key,
        tfrecord_file_path,
    )
    return _RepresentativeDatasetFile(
        tfrecord_file_path=str(tfrecord_file_path)
    )

  def save(
      self, representative_dataset: RepresentativeDatasetMapping
  ) -> Mapping[str, _RepresentativeDatasetFile]:
    """Saves the representative dataset.

    Args:
      representative_dataset: Signature def key -> representative dataset
        mapping. Each dataset is saved in a separate TFRecord file whose path
        matches the signature def key of `path_map`.

    Raises:
      ValueError: When the signature def key in `representative_dataset` is not
      present in the `path_map`.

    Returns:
      A map from signature key to the RepresentativeDatasetFile instance
      contains the path to the saved file.
    """
    dataset_file_map = {}
    for signature_def_key, repr_ds in representative_dataset.items():
      if signature_def_key not in self.path_map:
        raise ValueError(
            'SignatureDef key does not exist in the provided path_map:'
            f' {signature_def_key}'
        )

      dataset_file_map[signature_def_key] = self._save_tf_record_dataset(
          repr_ds, signature_def_key
      )
    return dataset_file_map
class RepresentativeDatasetLoader:
  """Base class for representative dataset loaders.

  Subclasses implement the `load` method, which loads the representative
  dataset from files.
  """

  def load(self) -> RepresentativeDatasetMapping:
    """Loads the representative datasets.

    Returns:
      representative dataset mapping: A loaded signature def key ->
      representative mapping.

    Raises:
      NotImplementedError: Always; this base implementation must be overridden.
    """
    not_implemented_message = 'Method "load" is not implemented.'
    raise NotImplementedError(not_implemented_message)
class TfRecordRepresentativeDatasetLoader(RepresentativeDatasetLoader):
  """Loader for representative datasets stored in TFRecord files."""

  def __init__(
      self,
      dataset_file_map: Mapping[str, _RepresentativeDatasetFile],
  ) -> None:
    """Initializes TFRecord representative dataset loader.

    Nothing is read at construction time; the mapping is only stored.

    Args:
      dataset_file_map: Signature key -> `RepresentativeDatasetFile` mapping.
    """
    self.dataset_file_map = dataset_file_map

  def _load_tf_record(self, tf_record_path: str) -> RepresentativeDataset:
    """Loads TFRecord containing samples of type `RepresentativeDataSample`.

    Args:
      tf_record_path: Path of the TFRecord file to read.

    Returns:
      A list of samples, each an input key -> numpy ndarray dict.

    Raises:
      DecodeError: If the sample is not RepresentativeDataSample.
    """
    loaded_samples = []
    # The records are read under eager mode so that `.numpy()` is available.
    with context.eager_mode():
      records = readers.TFRecordDatasetV2(filenames=[tf_record_path])
      for serialized_sample in records:
        sample_proto = _RepresentativeDataSample.FromString(
            serialized_sample.numpy()
        )
        loaded_samples.append({
            input_key: tensor_util.MakeNdarray(tensor_proto)
            for input_key, tensor_proto in (
                sample_proto.tensor_proto_inputs.items()
            )
        })
    return loaded_samples

  def load(self) -> RepresentativeDatasetMapping:
    """Loads the representative datasets.

    Returns:
      representative dataset mapping: A signature def key -> representative
      mapping. The loader loads `RepresentativeDataset` for each path in
      `self.dataset_file_map` and associates the loaded dataset to the
      corresponding signature def key.

    Raises:
      ValueError: If a dataset file does not have `tfrecord_file_path` set.
    """
    loaded_datasets = {}
    for signature_def_key, dataset_file in self.dataset_file_map.items():
      # TFRecord is the only file type handled here.
      if not dataset_file.HasField('tfrecord_file_path'):
        raise ValueError('Unsupported Representative Dataset filetype')
      loaded_datasets[signature_def_key] = self._load_tf_record(
          dataset_file.tfrecord_file_path
      )
    return loaded_datasets
def replace_tensors_by_numpy_ndarrays(
    repr_ds: RepresentativeDataset, sess: session.Session
) -> RepresentativeDataset:
  """Evaluates every tf.Tensor found in the samples of `repr_ds`.

  Note: This should be run in graph mode (default in TF1) only.

  Args:
    repr_ds: Representative dataset whose tf.Tensor values are to be replaced
      by their evaluated values. `repr_ds` is iterated through, so it may not
      be reusable afterwards (e.g. if it is a generator object).
    sess: Session instance used to evaluate tf.Tensors.

  Returns:
    A new list of samples in which each tf.Tensor value is replaced by the
    result of evaluating it in `sess`; all other values are kept as they are.
  """
  return [
      {
          # Only tensors need evaluation to obtain the actual value.
          input_key: (
              input_data.eval(session=sess)
              if isinstance(input_data, core.Tensor)
              else input_data
          )
          for input_key, input_data in sample.items()
      }
      for sample in repr_ds
  ]
def get_num_samples(repr_ds: RepresentativeDataset) -> Optional[int]:
  """Returns the number of samples if known.

  Args:
    repr_ds: Representative dataset.

  Returns:
    The total number of samples in `repr_ds` if it can be determined without
    iterating the entire dataset, otherwise None. A None result does not mean
    the representative dataset is infinite or malformed; it simply means the
    size cannot be determined without iterating the whole dataset.
  """
  if not isinstance(repr_ds, Sized):
    return None

  try:
    num_samples = len(repr_ds)
  except Exception as ex:  # pylint: disable=broad-except
    # There are some cases where calling __len__() raises an exception.
    # Handle this as if the size is unknown.
    logging.info('Cannot determine the size of the dataset (%s).', ex)
    return None
  return num_samples
def create_feed_dict_from_input_data(
    input_data: RepresentativeSample,
    signature_def: meta_graph_pb2.SignatureDef,
) -> Mapping[str, np.ndarray]:
  """Builds a feed dict out of an 'input key -> input value' mapping.

  Note: This function should only be used in graph mode.

  Each input key is translated to the name of the matching input tensor in
  `signature_def`, producing an 'input tensor name -> input value' mapping
  that can be passed as the `feed_dict` argument of `sess.run()`.

  Args:
    input_data: Input key -> input value mapping. The input keys should match
      the input keys of `signature_def`.
    signature_def: A SignatureDef representing the function that `input_data`
      is an input to.

  Returns:
    Feed dict intended to be used as input for `sess.run`. Values that were
    `Tensor`s are replaced by the result of `Tensor.eval()`, so no value in
    the returned dict is a `Tensor`.
  """
  return {
      # The key expression is evaluated before the value expression, so the
      # tensor name lookup still happens before the Tensor is evaluated.
      signature_def.inputs[input_key].name: (
          input_value.eval()
          if isinstance(input_value, core.Tensor)
          else input_value
      )
      for input_key, input_value in input_data.items()
  }
@@ -0,0 +1,346 @@
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Defines utilities involving SavedModel."""
from typing import Collection, Dict, Mapping, Optional, Sequence
from absl import logging
# pylint: disable=g-importing-member
from google.protobuf.any_pb2 import Any
# pylint: enable=g-importing-member
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.core.protobuf import saver_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import file_io
from tensorflow.python.saved_model import builder
from tensorflow.python.saved_model import constants as saved_model_constants
from tensorflow.python.saved_model import loader_impl as saved_model_loader
from tensorflow.python.saved_model import tag_constants
from tensorflow.python.training import saver
# Mapping of signature def key -> SignatureDef.
_SignatureDefMap = Mapping[str, meta_graph_pb2.SignatureDef]
def get_signatures_from_saved_model(
    saved_model_path: str,
    signature_keys: Optional[Sequence[str]] = None,
    tags: Optional[Collection[str]] = None,
) -> Dict[str, meta_graph_pb2.SignatureDef]:
  """Collects the SignatureDefs of a SavedModel, keyed by signature key.

  Args:
    saved_model_path: Path to the saved model.
    signature_keys: Keys of the SignatureDefs to retrieve. If None, every
      signature except the init signature is retrieved.
    tags: Set of tags identifying the MetaGraphDef within the SavedModel. If
      None, `{tag_constants.SERVING}` is used.

  Returns:
    A map from signature_key to its SignatureDef. The init op signature is
    never included.
  """
  meta_graph_tags = {tag_constants.SERVING} if tags is None else tags
  meta_graphdef = saved_model_loader.SavedModelLoader(
      saved_model_path
  ).get_meta_graph_def_from_tags(meta_graph_tags)

  def _is_requested(key):
    # The init signature is always excluded, even if explicitly listed.
    if key == saved_model_constants.INIT_OP_SIGNATURE_KEY:
      return False
    return signature_keys is None or key in signature_keys

  return {
      key: signature_def
      for key, signature_def in meta_graphdef.signature_def.items()
      if _is_requested(key)
  }
def _restore_output_tensor_names(
    graph_def: graph_pb2.GraphDef,
) -> graph_pb2.GraphDef:
  """Restores the output tensor names of the converted model.

  During the conversion, the output tensor names of the original model are
  embedded in the `tf_saved_model.index_path` attribute of the RetVal nodes and
  might become the name of Retval nodes as well (with an index suffix if there
  are multiple output tensors from one node). Since Retval nodes are not used in
  SavedModel, this function removes them and restores the names to the actual
  output tensors.

  Note: `graph_def` is modified in place; the returned object is the same
  `graph_def` that was passed in.

  Args:
    graph_def: the converted GraphDef.

  Returns:
    The GraphDef with Retval nodes removed and output tensor names restored.
  """
  # Maps the name of the node feeding a `_Retval` node to the name that node
  # should carry once the `_Retval` node is gone.
  output_renaming_map = {}
  # Import the GraphDef into a fresh graph so that op types, attributes and
  # inputs can be inspected through the `ops.Graph` API.
  with session.Session(graph=ops.Graph()):
    importer.import_graph_def(graph_def, name='')
    graph = ops.get_default_graph()
    for op in graph.get_operations():
      if op.type == '_Retval':
        # By default the producer takes over the `_Retval` node's own name.
        expected_node_name = op.name
        if op.get_attr('tf_saved_model.index_path') is not None:
          # Use the first entry of the attribute, decoded as UTF-8, and keep
          # only the part before the first ':'.
          index_path_name = op.get_attr('tf_saved_model.index_path')[0]
          index_path_name = index_path_name.decode('utf-8').split(':')[0]
          try:
            # Only use the index_path name if it points to a Retval node.
            index_path_node = graph.get_operation_by_name(index_path_name)
            if index_path_node.type == '_Retval':
              expected_node_name = index_path_name
          except KeyError:
            # No op with that name in the graph: keep the default name.
            pass
        # Only the first input of the `_Retval` op is considered.
        retval_input_node_name = op.inputs[0].op.name
        output_renaming_map[retval_input_node_name] = expected_node_name
  # NOTE(review): the reverse iteration presumably keeps the traversal stable
  # while `_Retval` nodes are removed from `graph_def.node` - confirm.
  for node in reversed(graph_def.node):
    if node.name in output_renaming_map:
      # This node produced an output: give it the restored name.
      node.name = output_renaming_map[node.name]
    elif node.op == '_Retval':
      graph_def.node.remove(node)
    else:
      # Update the inputs referring to the pre-renaming node.
      for idx, input_name in enumerate(node.input):
        if input_name in output_renaming_map:
          node.input[idx] = output_renaming_map[input_name]
      # Update the control inputs referring to the pre-renaming node.
      # Control inputs are the entries prefixed with '^'. Stale ones are
      # removed first and their renamed versions are appended afterwards.
      updating_inputs = []
      for input_name in reversed(node.input):
        if input_name.startswith('^') and input_name[1:] in output_renaming_map:
          updating_inputs.append(input_name[1:])
          node.input.remove(input_name)
      for updating_input in updating_inputs:
        node.input.append('^' + output_renaming_map[updating_input])
  return graph_def
def create_empty_output_dir(
    output_directory: str, overwrite: bool = True
) -> None:
  """Creates `output_directory`, optionally wiping a pre-existing one first.

  When `overwrite` is True and `output_directory` already exists, the
  directory is recursively deleted before being created again. When
  `overwrite` is False, existing contents are left untouched.

  Parent and intermediate directories are created as well.

  Args:
    output_directory: Output directory.
    overwrite: Whether to clean the output directory if it already exists.
  """
  # `file_exists_v2` is only consulted when `overwrite` is set.
  stale_dir_present = overwrite and file_io.file_exists_v2(output_directory)
  if stale_dir_present:
    logging.info(
        'Deleting existing output directory: %s .',
        output_directory,
    )
    file_io.delete_recursively_v2(output_directory)

  file_io.recursive_create_dir_v2(output_directory)
def _resolve_tensor_info_name(
    tensor_info: meta_graph_pb2.TensorInfo,
    signature_key: str,
    exported_graph: ops.Graph,
    kind: str,
) -> None:
  """Makes `tensor_info.name` refer to a tensor that exists in the graph.

  The name is kept when the graph has a tensor with that name. Otherwise the
  name prefixed with `<signature_key>_` is tried and, if that tensor exists,
  written back to `tensor_info.name`.

  Args:
    tensor_info: The TensorInfo to check; its `name` may be updated in place.
    signature_key: Key of the signature that owns `tensor_info`.
    exported_graph: The graph to look the tensor up in.
    kind: Either 'input' or 'output'; only used in the error message.

  Raises:
    ValueError: If neither the plain nor the prefixed name is in the graph.
  """
  try:
    exported_graph.get_tensor_by_name(tensor_info.name)
  except KeyError as exc:
    prefixed_name = signature_key + '_' + tensor_info.name
    try:
      exported_graph.get_tensor_by_name(prefixed_name)
    except KeyError:
      # Report the original (unprefixed) name and chain the first lookup
      # failure as the cause.
      raise ValueError(
          'Cannot find the %s tensor with name %s in the graph.'
          % (kind, tensor_info.name)
      ) from exc
    tensor_info.name = prefixed_name


def _validate_signatures(
    signature_def_map: _SignatureDefMap, exported_graph: ops.Graph
) -> _SignatureDefMap:
  """Validates if the tensor names in signatures are consistent with the graph.

  This function checks if the input and output tensor names in the signatures
  exist in the graph. The output tensor names might change during conversion;
  we try to fix that with `_restore_output_tensor_names`. Besides, if there
  are duplicated tensor names, they will be prefixed with the signature name.
  However, if that doesn't work the signatures can't be used with the converted
  graph.

  Note: the SignatureDefs in `signature_def_map` are updated in place when a
  prefixed name has to be used.

  Args:
    signature_def_map: the signatures to validate.
    exported_graph: The PTQ-exported graph.

  Returns:
    The signatures with tensor names prefixed with signature name if necessary.

  Raises:
    ValueError: Iff the signatures are not consistent with the graph.
  """
  for signature_key, signature_def in signature_def_map.items():
    # Inputs are validated before outputs, one signature at a time.
    for kind, tensor_infos in (
        ('input', signature_def.inputs),
        ('output', signature_def.outputs),
    ):
      for tensor_info in tensor_infos.values():
        _resolve_tensor_info_name(
            tensor_info, signature_key, exported_graph, kind
        )
  return signature_def_map
def _find_op(
    graph: ops.Graph, op_name: Optional[str]
) -> Optional[ops.Operation]:
  """Looks up the operation named `op_name` in `graph`.

  Args:
    graph: The graph to find from.
    op_name: Name of the node.

  Returns:
    The operation that corresponds to `op_name`, or None when `op_name` is an
    empty string or None (in which case `graph` is not consulted).

  Raises:
    ValueError: `op_name` is malformed.
  """
  # NOTE(review): any other exception raised by `get_operation_by_name` (e.g.
  # for a name that is not in the graph) propagates unchanged - confirm the
  # exact exception type against the `Graph` API docs.
  if op_name:
    found_op = graph.get_operation_by_name(op_name)
    logging.debug('Op found in the graph: %s', op_name)
    return found_op
  return None
def _save_function_alias(
    saved_model_dir: str,
    tags: Collection[str],
    function_aliases: Mapping[str, str],
) -> None:
  """Writes `function_aliases` into an already exported SavedModel.

  SavedModelBuilder (TF1 saved model saver) does not support saving function
  aliases, so this function loads the SavedModel proto, fills in the
  `function_aliases` field of the selected meta graph and writes the proto
  back to the SavedModel file.

  Args:
    saved_model_dir: Path to the saved model directory.
    tags: A collection of tags to specify the meta graph.
    function_aliases: Function name -> function alias mapping.
  """
  loader = saved_model_loader.SavedModelLoader(saved_model_dir)
  alias_map = loader.get_meta_graph_def_from_tags(
      tags
  ).meta_info_def.function_aliases
  for function_name, function_alias in function_aliases.items():
    alias_map[function_name] = function_alias

  # Serialize the whole SavedModel proto, which now carries the aliases.
  serialized_saved_model = loader.saved_model.SerializeToString()

  # TODO(b/266015731): Also update and set the SavedModel fingerprint.
  saved_model_pb_path = file_io.join(
      saved_model_dir, saved_model_constants.SAVED_MODEL_FILENAME_PB
  )
  file_io.atomic_write_string_to_file(
      saved_model_pb_path, serialized_saved_model
  )
def save_model_v1(
    graph_def: graph_pb2.GraphDef,
    output_dir: str,
    signature_def_map: _SignatureDefMap,
    tags: Collection[str],
    init_op_name: Optional[str] = None,
    saver_def: Optional[saver_pb2.SaverDef] = None,
    checkpoint_dir: Optional[str] = None,
    function_aliases: Optional[Mapping[str, str]] = None,
    asset_file_defs: Sequence[meta_graph_pb2.AssetFileDef] = (),
) -> None:
  """Saves the model.

  Saves the provided graph def as SavedModel.
  Uses TF1 SavedModel semantics (i.e. no object graph).

  The steps are: (re)create `output_dir`, restore the output tensor names in
  `graph_def`, import it into a fresh graph, validate the signatures against
  that graph, register the asset file defs, optionally restore variables from
  `checkpoint_dir`, export with `SavedModelBuilder`, and finally patch the
  function aliases into the written SavedModel.

  Args:
    graph_def: Graph to save.
    output_dir: Output directory for the SavedModel. Any existing directory at
      this path is deleted first (`create_empty_output_dir` is called with its
      default `overwrite=True`).
    signature_def_map: Mapping of signature def key -> SignatureDef.
    tags: Tags for the meta graph def.
    init_op_name: Name of the node for initialization.
    saver_def: `saver_pb2.SaverDef` to create a `saver.Saver` from. The created
      saver will be used to save and load variables. This may be `None` if no
      variables exist in the graph.
    checkpoint_dir: Path to checkpoint file where variable values are saved.
      Only used when `saver_def` is set.
    function_aliases: Function name -> function alias mapping.
    asset_file_defs: `AssetFileDef`s that associates the asset files and the
      name of the tensors to which the asset file names should be fed. The
      caller should make sure the asset files exist in the output saved model
      directory.

  Raises:
    ValueError iff the graph does not contain a valid signature or the file
    prefix tensor is not found in the graph.
  """
  create_empty_output_dir(output_dir)
  v1_builder = builder.SavedModelBuilder(output_dir)
  graph_def = _restore_output_tensor_names(graph_def)
  with session.Session(graph=ops.Graph()) as sess:
    importer.import_graph_def(graph_def, name='')
    # May rewrite tensor names inside the signatures (prefixing them with the
    # signature key) so that they match the imported graph.
    signature_def_map = _validate_signatures(
        signature_def_map, ops.get_default_graph()
    )
    # Add `AssetFileDef`s to the collection so that correct values are fed to
    # the tensors that accept asset file paths.
    for asset_file_def in asset_file_defs:
      asset_any_proto = Any()
      asset_any_proto.Pack(asset_file_def)
      ops.add_to_collection(
          saved_model_constants.ASSETS_KEY,
          asset_any_proto,
      )
    model_saver = None
    # If `saver_def` is not None, it means there are variables in the graph.
    if saver_def:
      model_saver = saver.Saver(saver_def=saver_def)
      logging.info('Saver created with SaverDef: %s', saver_def)
      # Variables should be restored once before exporting as saved model
      # because the variables are not initialized when the GraphDef was
      # imported.
      # NOTE(review): `checkpoint_dir` is forwarded as-is; presumably callers
      # always provide it together with `saver_def` - confirm.
      model_saver.restore(sess, checkpoint_dir)
    v1_builder.add_meta_graph_and_variables(
        sess,
        tags,
        signature_def_map=signature_def_map,
        # `_find_op` returns None when `init_op_name` is empty or None.
        main_op=_find_op(sess.graph, op_name=init_op_name),
        saver=model_saver,
    )
  v1_builder.save()
  # Function aliases are written after `save()` by patching the SavedModel
  # proto that the builder just wrote.
  if function_aliases:
    _save_function_alias(output_dir, tags, function_aliases)
@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# Checks the installed protobuf runtime against the version recorded in this
# generated module (5.28.3, public domain, for the .proto file named below).
# NOTE(review): presumably fails at import time on an incompatible runtime -
# confirm against the protobuf `runtime_version` documentation.
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC,
    5,
    28,
    3,
    '',
    'tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from tensorflow.compiler.mlir.quantization.stablehlo import quantization_config_pb2 as tensorflow_dot_compiler_dot_mlir_dot_quantization_dot_stablehlo_dot_quantization__config__pb2
from tensorflow.core.framework import tensor_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\nKtensorflow/compiler/mlir/quantization/tensorflow/quantization_options.proto\x12\x17tensorflow.quantization\x1aItensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto\x1a&tensorflow/core/framework/tensor.proto\"\xed\x02\n\x12QuantizationMethod\x12O\n\rpreset_method\x18\x04 \x01(\x0e\x32\x38.tensorflow.quantization.QuantizationMethod.PresetMethod\x12X\n\x1cquantization_component_specs\x18\x03 \x03(\x0b\x32\x32.tensorflow.quantization.QuantizationComponentSpec\"\xa5\x01\n\x0cPresetMethod\x12\x16\n\x12METHOD_UNSPECIFIED\x10\x00\x12\x16\n\x12METHOD_NO_QUANTIZE\x10\x01\x12\x1c\n\x18METHOD_STATIC_RANGE_INT8\x10\x02\x12\x1d\n\x19METHOD_DYNAMIC_RANGE_INT8\x10\x03\x12(\n$METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8\x10\x04J\x04\x08\x01\x10\x03\"\xbe\x03\n\x19QuantizationComponentSpec\x12h\n\x16quantization_component\x18\x01 \x01(\x0e\x32H.tensorflow.quantization.QuantizationComponentSpec.QuantizationComponent\x12R\n\x0btensor_type\x18\x02 \x01(\x0e\x32=.tensorflow.quantization.QuantizationComponentSpec.TensorType\"v\n\x15QuantizationComponent\x12\x19\n\x15\x43OMPONENT_UNSPECIFIED\x10\x00\x12\x18\n\x14\x43OMPONENT_ACTIVATION\x10\x01\x12\x14\n\x10\x43OMPONENT_WEIGHT\x10\x02\x12\x12\n\x0e\x43OMPONENT_BIAS\x10\x03\"k\n\nTensorType\x12\x1a\n\x16TENSORTYPE_UNSPECIFIED\x10\x00\x12\x14\n\x10TENSORTYPE_INT_4\x10\x01\x12\x14\n\x10TENSORTYPE_INT_8\x10\x02\x12\x15\n\x11TENSORTYPE_INT_32\x10\x03\"\x87\x02\n\x18UnitWiseQuantizationSpec\x12P\n\x04unit\x18\x05 \x03(\x0b\x32\x42.tensorflow.quantization.UnitWiseQuantizationSpec.QuantizationUnit\x12H\n\x13quantization_method\x18\x06 \x01(\x0b\x32+.tensorflow.quantization.QuantizationMethod\x1aI\n\x10QuantizationUnit\x12\x0f\n\x07op_type\x18\x01 \x01(\t\x12\x11\n\tnode_name\x18\x02 \x01(\t\x12\x11\n\tfunc_name\x18\x03 \x01(\tJ\x04\x08\x01\x10\x05\"\xd4\x01\n\x18RepresentativeDataSample\x12\x65\n\x13tensor_proto_inputs\x18\x02 
\x03(\x0b\x32H.tensorflow.quantization.RepresentativeDataSample.TensorProtoInputsEntry\x1aQ\n\x16TensorProtoInputsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12&\n\x05value\x18\x02 \x01(\x0b\x32\x17.tensorflow.TensorProto:\x02\x38\x01\"I\n\x19RepresentativeDatasetFile\x12\x1c\n\x12tfrecord_file_path\x18\x01 \x01(\tH\x00\x42\x0e\n\x0c\x64\x61taset_file\"\xca\x07\n\x13QuantizationOptions\x12H\n\x13quantization_method\x18\x01 \x01(\x0b\x32+.tensorflow.quantization.QuantizationMethod\x12.\n\x06op_set\x18\x02 \x01(\x0e\x32\x1e.tensorflow.quantization.OpSet\x12W\n\x1cunit_wise_quantization_specs\x18\x11 \x03(\x0b\x32\x31.tensorflow.quantization.UnitWiseQuantizationSpec\x12\x0c\n\x04tags\x18\x05 \x03(\t\x12\x16\n\x0esignature_keys\x18\x06 \x03(\t\x12i\n\x17representative_datasets\x18\x07 \x03(\x0b\x32H.tensorflow.quantization.QuantizationOptions.RepresentativeDatasetsEntry\x12$\n\x1cmin_num_elements_for_weights\x18\x08 \x01(\x03\x12!\n\x14\x66reeze_all_variables\x18\t \x01(\x08H\x00\x88\x01\x01\x12,\n\x1f\x65nable_per_channel_quantization\x18\n \x01(\x08H\x01\x88\x01\x01\x12 \n\x18\x65nable_two_input_tensors\x18\x0b \x01(\x08\x12-\n%experimental_enable_tpu_model_support\x18\x0c \x01(\x08\x12!\n\x19\x65nable_legacy_weight_only\x18\r \x01(\x08\x12$\n\x1c\x66orce_graph_mode_calibration\x18\x0e \x01(\x08\x12G\n\x13\x63\x61libration_options\x18\x0f \x01(\x0b\x32*.stablehlo.quantization.CalibrationOptions\x12?\n\x0f\x64\x65\x62ugger_config\x18\x10 \x01(\x0b\x32&.stablehlo.quantization.DebuggerConfig\x1aq\n\x1bRepresentativeDatasetsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\x41\n\x05value\x18\x02 \x01(\x0b\x32\x32.tensorflow.quantization.RepresentativeDatasetFile:\x02\x38\x01\x42\x17\n\x15_freeze_all_variablesB\"\n _enable_per_channel_quantizationJ\x04\x08\x03\x10\x04*V\n\x05OpSet\x12\x16\n\x12OP_SET_UNSPECIFIED\x10\x00\x12\x06\n\x02TF\x10\x01\x12\x07\n\x03XLA\x10\x02\x12\x15\n\x11UNIFORM_QUANTIZED\x10\x03\x12\r\n\tSTABLEHLO\x10\x04\x62\x06proto3')
# The builder calls below receive this module's globals dict together with
# DESCRIPTOR; the `_globals[...]` lookups further down read entries such as
# '_OPSET' and '_QUANTIZATIONMETHOD' back out of it.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.compiler.mlir.quantization.tensorflow.quantization_options_pb2', _globals)
# Everything below runs only when the C descriptor implementation is not in
# use.
# NOTE(review): `_serialized_start` / `_serialized_end` are presumably the byte
# offsets of each message or enum inside the serialized file descriptor -
# confirm against the protobuf generated-code documentation.
if not _descriptor._USE_C_DESCRIPTORS:
  DESCRIPTOR._loaded_options = None
  _globals['_REPRESENTATIVEDATASAMPLE_TENSORPROTOINPUTSENTRY']._loaded_options = None
  _globals['_REPRESENTATIVEDATASAMPLE_TENSORPROTOINPUTSENTRY']._serialized_options = b'8\001'
  _globals['_QUANTIZATIONOPTIONS_REPRESENTATIVEDATASETSENTRY']._loaded_options = None
  _globals['_QUANTIZATIONOPTIONS_REPRESENTATIVEDATASETSENTRY']._serialized_options = b'8\001'
  _globals['_OPSET']._serialized_start=2565
  _globals['_OPSET']._serialized_end=2651
  _globals['_QUANTIZATIONMETHOD']._serialized_start=220
  _globals['_QUANTIZATIONMETHOD']._serialized_end=585
  _globals['_QUANTIZATIONMETHOD_PRESETMETHOD']._serialized_start=414
  _globals['_QUANTIZATIONMETHOD_PRESETMETHOD']._serialized_end=579
  _globals['_QUANTIZATIONCOMPONENTSPEC']._serialized_start=588
  _globals['_QUANTIZATIONCOMPONENTSPEC']._serialized_end=1034
  _globals['_QUANTIZATIONCOMPONENTSPEC_QUANTIZATIONCOMPONENT']._serialized_start=807
  _globals['_QUANTIZATIONCOMPONENTSPEC_QUANTIZATIONCOMPONENT']._serialized_end=925
  _globals['_QUANTIZATIONCOMPONENTSPEC_TENSORTYPE']._serialized_start=927
  _globals['_QUANTIZATIONCOMPONENTSPEC_TENSORTYPE']._serialized_end=1034
  _globals['_UNITWISEQUANTIZATIONSPEC']._serialized_start=1037
  _globals['_UNITWISEQUANTIZATIONSPEC']._serialized_end=1300
  _globals['_UNITWISEQUANTIZATIONSPEC_QUANTIZATIONUNIT']._serialized_start=1221
  _globals['_UNITWISEQUANTIZATIONSPEC_QUANTIZATIONUNIT']._serialized_end=1294
  _globals['_REPRESENTATIVEDATASAMPLE']._serialized_start=1303
  _globals['_REPRESENTATIVEDATASAMPLE']._serialized_end=1515
  _globals['_REPRESENTATIVEDATASAMPLE_TENSORPROTOINPUTSENTRY']._serialized_start=1434
  _globals['_REPRESENTATIVEDATASAMPLE_TENSORPROTOINPUTSENTRY']._serialized_end=1515
  _globals['_REPRESENTATIVEDATASETFILE']._serialized_start=1517
  _globals['_REPRESENTATIVEDATASETFILE']._serialized_end=1590
  _globals['_QUANTIZATIONOPTIONS']._serialized_start=1593
  _globals['_QUANTIZATIONOPTIONS']._serialized_end=2563
  _globals['_QUANTIZATIONOPTIONS_REPRESENTATIVEDATASETSENTRY']._serialized_start=2383
  _globals['_QUANTIZATIONOPTIONS_REPRESENTATIVEDATASETSENTRY']._serialized_end=2496
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,26 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""StableHLO Portable Python APIs.
This setup only exports the StableHLO Portable C++ APIs, which have
signatures that do not rely on MLIR classes.
Exporting all of MLIR Python bindings to TF OSS has high maintenance
implications, especially given the frequency that TF updates the revision of
LLVM used.
"""
# pylint: disable=wildcard-import
from .stablehlo_extension import *
@@ -0,0 +1,127 @@
"""Python wrappers around TensorFlow ops.
This file is MACHINE GENERATED! Do not edit.
"""
import collections
from tensorflow.python import pywrap_tfe as pywrap_tfe
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.security.fuzzing.py import annotation_types as _atypes
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.deprecation import deprecated_endpoints
from tensorflow.python.util import dispatch as _dispatch
from tensorflow.python.util.tf_export import tf_export
from typing import TypeVar, List, Any
from typing_extensions import Annotated
@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('mlir_passthrough_op')
def mlir_passthrough_op(inputs, mlir_module: str, Toutputs, name=None):
  r"""Python wrapper for the `MlirPassthroughOp` op.

  TODO: add a description of the op itself.

  In eager mode the fast path is tried first, then the type-based dispatcher
  and the eager fallback. Otherwise (or when the eager fallback raises a
  `_SymbolicException`) a node is added to the TensorFlow graph.

  Args:
    inputs: A list of `Tensor` objects.
    mlir_module: A `string`.
    Toutputs: A list of `tf.DTypes`.
    name: A name for the operation (optional).

  Returns:
    A list of `Tensor` objects of type `Toutputs`.
  """
  _ctx = _context._context or _context.context()
  tld = _ctx._thread_local_data
  if tld.is_eager:
    try:
      # Fast path: execute the op directly.
      _result = pywrap_tfe.TFE_Py_FastPathExecute(
        _ctx, "MlirPassthroughOp", name, inputs, "mlir_module", mlir_module,
        "Toutputs", Toutputs)
      return _result
    except _core._NotOkStatusException as e:
      _ops.raise_from_not_ok_status(e, name)
    except _core._FallbackException:
      # Fast path not applicable: continue with the slower path below.
      pass
    try:
      # Give registered type-based dispatchers the first chance to handle the
      # call; `NotImplemented` means none of them did.
      _result = _dispatcher_for_mlir_passthrough_op(
          (inputs, mlir_module, Toutputs, name,), None)
      if _result is not NotImplemented:
        return _result
      return mlir_passthrough_op_eager_fallback(
          inputs, mlir_module=mlir_module, Toutputs=Toutputs, name=name,
          ctx=_ctx)
    except _core._SymbolicException:
      pass  # Add nodes to the TensorFlow graph.
    except (TypeError, ValueError):
      # Try the fallback dispatch list before re-raising the original error.
      _result = _dispatch.dispatch(
            mlir_passthrough_op, (), dict(inputs=inputs,
                                          mlir_module=mlir_module,
                                          Toutputs=Toutputs, name=name)
          )
      if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
        return _result
      raise
  else:
    _result = _dispatcher_for_mlir_passthrough_op(
        (inputs, mlir_module, Toutputs, name,), None)
    if _result is not NotImplemented:
      return _result
  # Add nodes to the TensorFlow graph.
  mlir_module = _execute.make_str(mlir_module, "mlir_module")
  if not isinstance(Toutputs, (list, tuple)):
    raise TypeError(
        "Expected list for 'Toutputs' argument to "
        "'mlir_passthrough_op' Op, not %r." % Toutputs)
  Toutputs = [_execute.make_type(_t, "Toutputs") for _t in Toutputs]
  try:
    _, _, _op, _outputs = _op_def_library._apply_op_helper(
        "MlirPassthroughOp", inputs=inputs, mlir_module=mlir_module,
                             Toutputs=Toutputs, name=name)
  except (TypeError, ValueError):
    # Same fallback dispatch as in the eager branch above.
    _result = _dispatch.dispatch(
          mlir_passthrough_op, (), dict(inputs=inputs,
                                        mlir_module=mlir_module,
                                        Toutputs=Toutputs, name=name)
        )
    if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
      return _result
    raise
  # Copy of the op's outputs; this is what gets returned.
  _result = _outputs[:]
  if _execute.must_record_gradient():
    # Attributes are read back from the created op, including the "Tinputs"
    # attribute that is not a parameter of this wrapper.
    _attrs = ("mlir_module", _op.get_attr("mlir_module"), "Tinputs",
              _op.get_attr("Tinputs"), "Toutputs", _op.get_attr("Toutputs"))
    _inputs_flat = _op.inputs
    _execute.record_gradient(
        "MlirPassthroughOp", _inputs_flat, _attrs, _result)
  return _result

# Export the wrapper under `raw_ops.MlirPassthroughOp` as well.
MlirPassthroughOp = tf_export("raw_ops.MlirPassthroughOp")(_ops.to_raw_op(mlir_passthrough_op))
# Bound at import time; used inside `mlir_passthrough_op` on every call.
_dispatcher_for_mlir_passthrough_op = mlir_passthrough_op._tf_type_based_dispatcher.Dispatch
def mlir_passthrough_op_eager_fallback(inputs, mlir_module: str, Toutputs, name, ctx):
  """Executes `MlirPassthroughOp` through the generic eager execute path.

  Args:
    inputs: The op inputs; converted to eager tensors here.
    mlir_module: Value for the "mlir_module" attribute.
    Toutputs: A list or tuple of output dtypes.
    name: A name for the operation.
    ctx: The eager context to execute in.

  Returns:
    The result of `_execute.execute`, which is asked for `len(Toutputs)`
    outputs.

  Raises:
    TypeError: If `Toutputs` is neither a list nor a tuple.
  """
  module_str = _execute.make_str(mlir_module, "mlir_module")
  if not isinstance(Toutputs, (list, tuple)):
    raise TypeError(
        "Expected list for 'Toutputs' argument to "
        "'mlir_passthrough_op' Op, not %r." % Toutputs)
  output_types = [_execute.make_type(_t, "Toutputs") for _t in Toutputs]
  input_types, eager_inputs = _execute.convert_to_mixed_eager_tensors(
      inputs, ctx)
  flat_inputs = list(eager_inputs)
  op_attrs = (
      "mlir_module", module_str,
      "Tinputs", input_types,
      "Toutputs", output_types,
  )
  outputs = _execute.execute(
      b"MlirPassthroughOp", len(output_types), inputs=flat_inputs,
      attrs=op_attrs, ctx=ctx, name=name)
  if _execute.must_record_gradient():
    _execute.record_gradient(
        "MlirPassthroughOp", flat_inputs, op_attrs, outputs)
  return outputs
@@ -0,0 +1,30 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# LINT.IfChange(savedmodel_to_stablehlo)
# Type stub only; the implementation is not in this file.
# NOTE(review): going by the name and signature, this presumably converts the
# SavedModel at `input_path` (restricted to the given signatures and tags) to
# a StableHLO module returned as bytes - confirm against the extension module.
def savedmodel_to_stablehlo(
    input_path: str,
    exported_model_signatures: list[str] = ["serving_default"],
    tag_names: list[str] = ["serve"],
    input_arg_shapes_str: str = "",
) -> bytes: ...
# LINT.ThenChange()
# LINT.IfChange(tensorflow_module_to_stablehlo)
# Type stub only; the implementation is not in this file.
# NOTE(review): `module` is presumably the textual form of a TensorFlow MLIR
# module and the result the converted StableHLO module as bytes - confirm
# against the extension module.
def tensorflow_module_to_stablehlo(
    module: str,
    input_arg_shapes_str: str = "",
) -> bytes: ...
# LINT.ThenChange()
@@ -0,0 +1,19 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# Type stub. NOTE(review): the three ints are presumably (major, minor, patch)
# of the TensorRT version linked at build time - confirm.
def get_linked_tensorrt_version() -> tuple[int, int, int]: ...
# Type stub. NOTE(review): the three ints are presumably (major, minor, patch)
# of the TensorRT library loaded at runtime - confirm.
def get_loaded_tensorrt_version() -> tuple[int, int, int]: ...
# Type stub. NOTE(review): presumably the names of the ops that have a
# registered converter - confirm.
def get_registered_op_converters() -> list[str]: ...
# Type stub. NOTE(review): presumably reports whether this build has TensorRT
# support enabled - confirm.
def is_tensorrt_enabled() -> bool: ...
@@ -0,0 +1,23 @@
"""Python wrappers around TensorFlow ops.
This file is MACHINE GENERATED! Do not edit.
"""
import collections
from tensorflow.python import pywrap_tfe as pywrap_tfe
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.security.fuzzing.py import annotation_types as _atypes
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.deprecation import deprecated_endpoints
from tensorflow.python.util import dispatch as _dispatch
from tensorflow.python.util.tf_export import tf_export
from typing import TypeVar, List, Any
from typing_extensions import Annotated
@@ -0,0 +1,728 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental library that exposes XLA operations directly in TensorFlow.
It is sometimes useful to be able to build HLO programs directly from
TensorFlow. This file provides Tensorflow operators that mirror the semantics of
HLO operators as closely as possible.
Note: Most of the operators defined in this module are used by the jax2tf
converter (see go/jax2tf for details) and are used in SavedModel produced
by jax2tf. Hence, we need to maintain backwards compatibility for these
operators. Please reach out to the JAX team if you want to make changes.
"""
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.compiler.xla import xla_data_pb2
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import bitwise_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_random_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import random_ops_util
from tensorflow.python.ops import special_math_ops
from tensorflow.python.ops.numpy_ops import np_utils
# TODO(phawkins): provide wrappers for all XLA operators. Currently the missing
# ops include:
# infeed/outfeed (available via tf.contrib.tpu)
# collectives, e.g., cross-replica-sum (available via tf.contrib.tpu)
# conditional
# gather/scatter
# collapse
# This file reuses builtin names (following XLA's names, so we can call things
# like xla.max), so we capture the builtin versions here.
# pylint: disable=redefined-builtin
# Private handles on the Python builtins `max`, `min` and `slice`, captured
# before this module reuses builtin names for XLA-style operators.
_max = max
_min = min
_slice = slice # pylint: disable=invalid-name
# Makes `constant_op.constant` available as `constant` from this module.
constant = constant_op.constant
# Unary operators.
# For most arithmetic operators there is a TensorFlow operator
# that exactly corresponds to each XLA operator. Rather than defining
# XLA-specific variants, we reuse the corresponding TensorFlow operator.
# TODO(phawkins): It would be even better to have TensorFlow operators that 1:1
# wrap every HLO operator, because that would allow us to be confident that the
# semantics match.
def _unary_op(fn):
  """Returns a wrapper exposing `fn` with the fixed `(x, name=None)` signature.

  Args:
    fn: A callable accepting one positional argument and a `name` keyword.

  Returns:
    A function `unary_op_wrapper(x, name=None)` that forwards to
    `fn(x, name=name)` and returns its result.
  """

  def unary_op_wrapper(x, name=None):
    # `name` is always forwarded by keyword, so no other argument of `fn` can
    # be reached through the wrapper.
    result = fn(x, name=name)
    return result

  return unary_op_wrapper
# Each name below is the corresponding TensorFlow op restricted to the
# `(x, name=None)` signature by `_unary_op`. Several of them (`abs`, `round`)
# intentionally rebind Python builtins at module scope.
abs = _unary_op(math_ops.abs)
# TODO(phawkins): implement clz.
conj = _unary_op(math_ops.conj)
cos = _unary_op(math_ops.cos)
ceil = _unary_op(math_ops.ceil)
digamma = _unary_op(math_ops.digamma)
erf = _unary_op(math_ops.erf)
erfc = _unary_op(math_ops.erfc)
erfinv = _unary_op(math_ops.erfinv)
ndtri = _unary_op(math_ops.ndtri)
exp = _unary_op(math_ops.exp)
expm1 = _unary_op(math_ops.expm1)
floor = _unary_op(math_ops.floor)
imag = _unary_op(math_ops.imag)
is_finite = _unary_op(math_ops.is_finite)
lgamma = _unary_op(math_ops.lgamma)
log = _unary_op(math_ops.log)
log1p = _unary_op(math_ops.log1p)
logical_not = _unary_op(math_ops.logical_not)
neg = _unary_op(math_ops.neg)
real = _unary_op(math_ops.real)
# TODO(phawkins): unlike xla::Round, this rounds to even instead of away from
# zero for numbers halfway between two integers.
round = _unary_op(math_ops.round)
sin = _unary_op(math_ops.sin)
sign = _unary_op(math_ops.sign)
tan = _unary_op(math_ops.tan)
tanh = _unary_op(math_ops.tanh)
# Bessel functions (taken from special_math_ops rather than math_ops).
bessel_i0e = _unary_op(special_math_ops.bessel_i0e)
bessel_i1e = _unary_op(special_math_ops.bessel_i1e)
# Binary operators
# The main difference between TensorFlow and XLA binary ops is the broadcasting
# semantics. TensorFlow uses Numpy-style broadcasting semantics, whereas XLA
# requires an explicit specification of which dimensions to broadcast if the
# arguments have different ranks.
def _broadcasting_binary_op(fn):
  """Adapts a binary TensorFlow op so it uses XLA-style explicit broadcasting.

  Args:
    fn: callable invoked as `fn(x, y, name=name)` on the broadcast operands.

  Returns:
    A function with signature `(x, y, broadcast_dims=None, name=None)` that
    first passes both operands through the XlaBroadcastHelper op and then
    applies `fn` to the results.
  """

  def broadcasting_binary_op_wrapper(x, y, broadcast_dims=None, name=None):
    """Broadcasts `x` and `y` with XlaBroadcastHelper, then applies `fn`."""
    # A falsy `broadcast_dims` (None or an empty sequence) becomes an empty
    # int64 tensor.
    dims_tensor = ops.convert_to_tensor(broadcast_dims or [], dtypes.int64)
    # The helper op computes the broadcast shapes at JIT compilation time, so
    # no static shape information is required in the TensorFlow graph.
    lhs, rhs = gen_xla_ops.xla_broadcast_helper(x, y, dims_tensor)
    return fn(lhs, rhs, name=name)

  return broadcasting_binary_op_wrapper
# Map from TF signed integer types to the TF unsigned type of the same width.
_SIGNED_TO_UNSIGNED_TABLE = {
    dtypes.int8: dtypes.uint8,
    dtypes.int16: dtypes.uint16,
    dtypes.int32: dtypes.uint32,
    dtypes.int64: dtypes.uint64,
}
# Map from TF unsigned types to TF signed types: the exact inverse of the
# table above, with entries in the same order.
_UNSIGNED_TO_SIGNED_TABLE = {
    unsigned: signed for signed, unsigned in _SIGNED_TO_UNSIGNED_TABLE.items()
}
def _shift_right_logical_helper(x, y, name=None):
  """Performs an integer right logical shift irrespective of input type.

  Args:
    x: tensor holding the values to shift.
    y: tensor holding the shift amounts; must have the same dtype as `x`.
    name: optional name for the `right_shift` op.

  Returns:
    The shifted tensor, with the same dtype as `x`.
  """
  assert y.dtype == x.dtype
  original_dtype = x.dtype
  unsigned_dtype = _SIGNED_TO_UNSIGNED_TABLE.get(original_dtype)
  if unsigned_dtype is None:
    # Not one of the signed types in the table: shift the operands as they are.
    return bitwise_ops.right_shift(x, y, name=name)
  # Signed input: do the shift on the unsigned counterpart, then cast back.
  unsigned_x = math_ops.cast(x, unsigned_dtype)
  unsigned_y = math_ops.cast(y, unsigned_dtype)
  shifted = bitwise_ops.right_shift(unsigned_x, unsigned_y, name=name)
  return math_ops.cast(shifted, original_dtype)
def _shift_right_arithmetic_helper(x, y, name=None):
  """Performs an integer right arithmetic shift irrespective of input type.

  Args:
    x: tensor holding the values to shift.
    y: tensor holding the shift amounts; must have the same dtype as `x`.
    name: optional name for the `right_shift` op.

  Returns:
    The shifted tensor, with the same dtype as `x`.
  """
  assert y.dtype == x.dtype
  original_dtype = x.dtype
  signed_dtype = _UNSIGNED_TO_SIGNED_TABLE.get(original_dtype)
  if signed_dtype is None:
    # Not one of the unsigned types in the table: shift the operands as is.
    return bitwise_ops.right_shift(x, y, name=name)
  # Unsigned input: do the shift on the signed counterpart, then cast back.
  signed_x = math_ops.cast(x, signed_dtype)
  signed_y = math_ops.cast(y, signed_dtype)
  shifted = bitwise_ops.right_shift(signed_x, signed_y, name=name)
  return math_ops.cast(shifted, original_dtype)
# Binary operators with XLA-style broadcasting. Each wrapper accepts
# `(x, y, broadcast_dims=None, name=None)`; see `_broadcasting_binary_op`.
# `max`, `min`, `complex` and `pow` rebind Python builtins from here on; use
# `_max` / `_min` (captured near the top of the module) for the builtins.
add = _broadcasting_binary_op(math_ops.add)
sub = _broadcasting_binary_op(math_ops.sub)
mul = _broadcasting_binary_op(math_ops.mul)
div = _broadcasting_binary_op(math_ops.div)
rem = _broadcasting_binary_op(gen_math_ops.mod)
max = _broadcasting_binary_op(math_ops.maximum)
min = _broadcasting_binary_op(math_ops.minimum)
atan2 = _broadcasting_binary_op(math_ops.atan2)
complex = _broadcasting_binary_op(math_ops.complex)
# Logical operators.
logical_and = _broadcasting_binary_op(math_ops.logical_and)
logical_or = _broadcasting_binary_op(math_ops.logical_or)
logical_xor = _broadcasting_binary_op(math_ops.logical_xor)
# Comparisons.
eq = _broadcasting_binary_op(math_ops.equal)
ne = _broadcasting_binary_op(math_ops.not_equal)
ge = _broadcasting_binary_op(math_ops.greater_equal)
gt = _broadcasting_binary_op(math_ops.greater)
le = _broadcasting_binary_op(math_ops.less_equal)
lt = _broadcasting_binary_op(math_ops.less)
pow = _broadcasting_binary_op(math_ops.pow)
# Shifts. The two right shifts go through the helpers above, which cast the
# operands to the unsigned (logical) or signed (arithmetic) counterpart.
shift_left = _broadcasting_binary_op(bitwise_ops.left_shift)
shift_right_logical = _broadcasting_binary_op(_shift_right_logical_helper)
shift_right_arithmetic = _broadcasting_binary_op(_shift_right_arithmetic_helper)
# Gamma-family and related special functions.
igamma = _broadcasting_binary_op(math_ops.igamma)
igamma_grad_a = _broadcasting_binary_op(gen_math_ops.igamma_grad_a)
random_gamma_grad = _broadcasting_binary_op(gen_random_ops.random_gamma_grad)
igammac = _broadcasting_binary_op(math_ops.igammac)
polygamma = _broadcasting_binary_op(math_ops.polygamma)
zeta = _broadcasting_binary_op(math_ops.zeta)
def _binary_op(fn):
"""Wrapper that restricts `fn` to have the correct signature."""
def binary_op_wrapper(x, y, name=None):
return fn(x, y, name=name)
return binary_op_wrapper
# `transpose` and `rev` take `(x, y, name=None)`; the second operand is passed
# positionally to `array_ops.transpose` / `array_ops.reverse`.
transpose = _binary_op(array_ops.transpose)
rev = _binary_op(array_ops.reverse)
# Direct alias: callers get `array_ops.bitcast`'s own signature.
bitcast_convert_type = array_ops.bitcast
def broadcast(x, dims, name=None):
  """Broadcasts `x` by prepending the dimensions `dims` to its shape.

  Args:
    x: value convertible to a tensor.
    dims: sizes of the new leading dimensions; wrapped in a constant tensor.
    name: optional name for the `broadcast_to` op.

  Returns:
    `x` broadcast to the shape formed by concatenating `dims` with the
    (dynamic) shape of `x`.
  """
  operand = ops.convert_to_tensor(x)
  leading_dims = constant_op.constant(dims)
  operand_shape = array_ops.shape(operand)
  target_shape = array_ops.concat([leading_dims, operand_shape], axis=0)
  return array_ops.broadcast_to(operand, target_shape, name=name)
def clamp(a, x, b, name=None):
  """Computes `min(max(a, x), b)` using this module's `max` and `min`.

  Args:
    a: first operand of the inner `max` (the lower bound).
    x: the value being clamped.
    b: second operand of the outer `min` (the upper bound).
    name: optional name, passed to both the `max` and the `min` op.

  Returns:
    The result of the outer `min` op.
  """
  # `max` / `min` here are the broadcasting XLA wrappers, not the builtins.
  lower_bounded = max(a, x, name=name)
  return min(lower_bounded, b, name=name)
# Direct alias: callers get `array_ops.concat`'s own signature.
concatenate = array_ops.concat
def conv(
    lhs,
    rhs,
    window_strides,
    padding,
    lhs_dilation,
    rhs_dilation,
    dimension_numbers,
    feature_group_count=1,
    precision_config=None,
    preferred_element_type=None,
    name=None,
    use_v2=False,
    batch_group_count=1,
):
  """Wraps the XLA ConvGeneralDilated operator.

  ConvGeneralDilated is the most general form of XLA convolution and is
  documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#conv_convolution

  The XlaConvV2 op is emitted when `preferred_element_type` is set, when the
  operand dtypes differ, when `batch_group_count > 1`, or when `use_v2` is
  requested; otherwise the original XlaConv op is emitted.

  Args:
    lhs: the input tensor
    rhs: the kernel tensor
    window_strides: the inter-window strides
    padding: the padding to apply at the start and end of each input dimensions
    lhs_dilation: dilation to apply between input elements
    rhs_dilation: dilation to apply between kernel elements
    dimension_numbers: a `ConvolutionDimensionNumbers` proto.
    feature_group_count: number of feature groups for grouped convolution.
    precision_config: a `xla.PrecisionConfig` proto.
    preferred_element_type: the result `dtype`. When None, it defaults to
      `np_utils.result_type(lhs.dtype, rhs.dtype)` (only used by XlaConvV2).
    name: an optional name for the operator.
    use_v2: an optional request to use the XlaConvV2 op even if not necessary.
    batch_group_count: number of batch groups or grouped filters.

  Returns:
    A tensor representing the output of the convolution.
  """
  serialized_precision = (
      precision_config.SerializeToString() if precision_config else ""
  )
  # Must be decided before the default below overwrites
  # `preferred_element_type`.
  requires_v2 = (
      preferred_element_type
      or (lhs.dtype != rhs.dtype)
      or batch_group_count > 1
  )
  if preferred_element_type is None:
    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)

  # Keyword arguments accepted by both XlaConv and XlaConvV2.
  shared_kwargs = dict(
      window_strides=window_strides,
      padding=padding,
      lhs_dilation=lhs_dilation,
      rhs_dilation=rhs_dilation,
      feature_group_count=feature_group_count,
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=serialized_precision,
      name=name,
  )
  if not (requires_v2 or use_v2):
    return gen_xla_ops.xla_conv(lhs, rhs, **shared_kwargs)
  return gen_xla_ops.xla_conv_v2(
      lhs,
      rhs,
      batch_group_count=batch_group_count,
      preferred_element_type=preferred_element_type,
      **shared_kwargs,
  )
# Direct alias: callers get `math_ops.cast`'s own signature.
convert_element_type = math_ops.cast
def dot(lhs, rhs, name=None):
  """Computes `math_ops.tensordot(lhs, rhs, axes=1)`.

  Args:
    lhs: the left operand.
    rhs: the right operand.
    name: optional name for the operation.

  Returns:
    The tensordot result, contracting over a single axis pair (`axes=1`).
  """
  product = math_ops.tensordot(lhs, rhs, axes=1, name=name)
  return product
# Re-export the XLA proto message classes so callers can build them through
# this module.
DotDimensionNumbers = xla_data_pb2.DotDimensionNumbers
PrecisionConfig = xla_data_pb2.PrecisionConfig
def dot_general(
    lhs,
    rhs,
    dimension_numbers,
    precision_config=None,
    preferred_element_type=None,
    name=None,
    use_v2=False,
):
  """Emits an XlaDot op, or XlaDotV2 when required or requested.

  XlaDotV2 is used when `preferred_element_type` is set, when the operand
  dtypes differ, or when `use_v2` is true.

  Args:
    lhs: the left operand; must expose `.dtype`.
    rhs: the right operand; must expose `.dtype`.
    dimension_numbers: a proto (presumably `DotDimensionNumbers`) that is
      serialized with `SerializeToString()`.
    precision_config: optional proto; serialized when truthy, otherwise an
      empty string is passed.
    preferred_element_type: optional result dtype. When None, it defaults to
      `np_utils.result_type(lhs.dtype, rhs.dtype)` (only used by XlaDotV2).
    name: an optional name for the operator.
    use_v2: request the XlaDotV2 op even if not necessary.

  Returns:
    The output of the emitted dot op.
  """
  serialized_precision = (
      precision_config.SerializeToString() if precision_config else ""
  )
  # Must be decided before the default below overwrites
  # `preferred_element_type`.
  requires_v2 = preferred_element_type or (lhs.dtype != rhs.dtype)
  if preferred_element_type is None:
    preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)

  # Keyword arguments accepted by both XlaDot and XlaDotV2.
  shared_kwargs = dict(
      dimension_numbers=dimension_numbers.SerializeToString(),
      precision_config=serialized_precision,
      name=name,
  )
  if not (requires_v2 or use_v2):
    return gen_xla_ops.xla_dot(lhs, rhs, **shared_kwargs)
  return gen_xla_ops.xla_dot_v2(
      lhs,
      rhs,
      preferred_element_type=preferred_element_type,
      **shared_kwargs,
  )
def self_adjoint_eig(a, lower, max_iter, epsilon):
  """Forwards its arguments, in order, to the XlaSelfAdjointEig op wrapper.

  Args:
    a: first positional argument of `gen_xla_ops.xla_self_adjoint_eig`.
    lower: second positional argument.
    max_iter: third positional argument.
    epsilon: fourth positional argument.

  Returns:
    Whatever `gen_xla_ops.xla_self_adjoint_eig` returns.
  """
  result = gen_xla_ops.xla_self_adjoint_eig(a, lower, max_iter, epsilon)
  return result
def svd(a, max_iter, epsilon, precision_config=None):
  """Forwards to the XlaSvd op wrapper with a serialized precision config.

  Args:
    a: first positional argument of `gen_xla_ops.xla_svd`.
    max_iter: second positional argument.
    epsilon: third positional argument.
    precision_config: optional proto; serialized with `SerializeToString()`
      when truthy, otherwise an empty string is passed.

  Returns:
    Whatever `gen_xla_ops.xla_svd` returns.
  """
  serialized_precision = (
      precision_config.SerializeToString() if precision_config else ""
  )
  return gen_xla_ops.xla_svd(a, max_iter, epsilon, serialized_precision)
# Direct aliases of the generated op wrappers; callers get the generated
# signatures unchanged.
dynamic_slice = gen_xla_ops.xla_dynamic_slice
dynamic_update_slice = gen_xla_ops.xla_dynamic_update_slice
einsum = gen_xla_ops.xla_einsum
# TODO(phawkins): generalize tf.pad to support interior padding, and then remove
# the XLA-specific pad operator.
pad = gen_xla_ops.xla_pad
def random_normal(mu, sigma, dims, name=None):
  """Samples from a normal distribution via `random_ops.random_normal`.

  Args:
    mu: the mean; converted to a tensor, whose dtype is the output dtype.
    sigma: the standard deviation, passed through as `stddev`.
    dims: the output shape.
    name: optional name for the operation.

  Returns:
    The result of `random_ops.random_normal`.
  """
  mean = ops.convert_to_tensor(mu)
  return random_ops.random_normal(
      dims, mean=mean, stddev=sigma, dtype=mean.dtype, name=name
  )
def random_uniform(minval, maxval, dims, name=None):
  """Samples from a uniform distribution via `random_ops.random_uniform`.

  Args:
    minval: the lower bound; converted to a tensor, whose dtype is the output
      dtype.
    maxval: the upper bound, passed through unchanged.
    dims: the output shape.
    name: optional name for the operation.

  Returns:
    The result of `random_ops.random_uniform`.
  """
  lower = ops.convert_to_tensor(minval)
  return random_ops.random_uniform(
      dims, lower, maxval, dtype=lower.dtype, name=name
  )
def rng_bit_generator(algorithm, initial_state, shape, dtype):
  """Stateless PRNG bit generator.

  Wraps the XLA RngBitGenerator operator, documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#rngbitgenerator.

  Args:
    algorithm: The PRNG algorithm to use, one of tf.random.Algorithm.{PHILOX,
      THREEFRY, AUTO_SELECT}. It is converted to its integer code with
      `random_ops_util.convert_alg_to_int` before being passed to the op.
    initial_state: Initial state for the PRNG algorithm. For THREEFRY, it should
      be a u64[2] and for PHILOX a u64[3].
    shape: The output shape of the generated data.
    dtype: The type of the tensor.

  Returns:
    a tuple with a new state and generated data of the given shape.
  """
  algorithm_code = random_ops_util.convert_alg_to_int(algorithm)
  return gen_xla_ops.xla_rng_bit_generator(
      algorithm_code, initial_state, shape, dtype=dtype
  )
# Direct aliases of the generated op wrappers.
recv = gen_xla_ops.xla_recv
reduce = gen_xla_ops.xla_reduce
variadic_reduce = gen_xla_ops.xla_variadic_reduce_v2
# Registers the "XlaVariadicReduce" op type as having no gradient.
# NOTE(review): `variadic_reduce` above is bound to the V2 wrapper while this
# registration names the non-V2 op type -- confirm whether that is intended.
ops.no_gradient("XlaVariadicReduce")
def reduce_window(
    operand,
    init,
    reducer,
    window_dimensions,
    window_strides=None,
    base_dilations=None,
    window_dilations=None,
    padding=None,
    name=None,
):
  """Wraps the XLA ReduceWindow operator.

  ReduceWindow is documented at
  https://www.tensorflow.org/performance/xla/operation_semantics#reducewindow .

  Args:
    operand: the input tensor
    init: a scalar tensor representing the initial value for the reduction
    reducer: a reduction function that combines a pair of scalars.
    window_dimensions: shape of the window, as a list of integers
    window_strides: inter-window strides, as a list of integers. Optional; if
      omitted, defaults to strides of 1.
    base_dilations: dilation factors for the operand, as a list of integers.
      Optional; if omitted, defaults to 1 for every window dimension.
    window_dilations: dilation factors for the window, as a list of integers.
      Optional; if omitted, defaults to 1 for every window dimension.
    padding: padding to apply to 'operand'. List of (low, high) pairs of
      integers that specify the padding to apply before and after each
      dimension. Optional; if omitted, defaults to no padding.
    name: the operator name, or None.

  Returns:
    A tensor that represents the output of the reduce_window operator.
  """
  # Any falsy optional argument (None or empty) is replaced by its default,
  # sized to match `window_dimensions`.
  if not window_strides:
    window_strides = [1] * len(window_dimensions)
  if not base_dilations:
    base_dilations = [1] * len(window_dimensions)
  if not window_dilations:
    window_dilations = [1] * len(window_dimensions)
  if not padding:
    padding = [(0, 0)] * len(window_dimensions)

  return gen_xla_ops.xla_reduce_window(
      input=operand,
      init_value=init,
      window_dimensions=window_dimensions,
      window_strides=window_strides,
      base_dilations=base_dilations,
      window_dilations=window_dilations,
      padding=padding,
      computation=reducer,
      name=name,
  )
# Direct alias of the generated XlaReplicaId op wrapper.
replica_id = gen_xla_ops.xla_replica_id
# Set a static bound for the given input value as a hint to the XLA compiler;
# returns the same value.
# Usage:
#   def f(t, p):
#     p = xla.set_bound(p, 3)  # Tells XLA the constraint that p <= 3.
#     return t[:p]  # XLA knows the bound of the slice is 3.
set_bound = gen_xla_ops.xla_set_bound
# Make a static dimension into an XLA bounded dynamic dimension. The current
# static dimension size will become the bound and the second operand becomes the
# dynamic size of the dimension.
#
# This should mostly be used for testing.
#
#   def f():
#     array = tf.convert_to_tensor([1, 2, 3, 4, 5])
#     # Tells XLA the valid size of the array is 3.
#     dim = 0
#     p = xla.set_dynamic_dimension_size(array, dim, 3)
#     assert(reduce_sum(p) == 6)  # XLA knows only the first 3 elements are valid.
#
# NOTE(review): the example previously used the rank-2 literal
# `[[1, 2, 3, 4, 5]]` with `dim = 0`, which does not match the asserted sum;
# it is shown here with a rank-1 array -- confirm against the op's semantics.
set_dynamic_dimension_size = gen_xla_ops.xla_set_dynamic_dimension_size
# Inverse of set_dynamic_dimension_size. Make an XLA bounded dynamic
# dimension into a static dimension. The bound of the size of dimension
# `dim_index` becomes the static dimension size.
remove_dynamic_dimension_size = gen_xla_ops.xla_remove_dynamic_dimension_size
def reshape(x, new_sizes, dimensions=None, name=None):
  """Reshapes `x`, optionally transposing it first.

  Args:
    x: the tensor to reshape.
    new_sizes: the target shape, passed to `array_ops.reshape`.
    dimensions: optional permutation; when not None, `x` is transposed with it
      before reshaping.
    name: optional name for the reshape op (not applied to the transpose).

  Returns:
    The reshaped tensor.
  """
  operand = x if dimensions is None else array_ops.transpose(x, dimensions)
  return array_ops.reshape(operand, new_sizes, name=name)
def select(condition, x, y, name=None):
  """Selects between `x` and `y` via `array_ops.where`.

  Args:
    condition: the predicate tensor.
    x: second positional argument of `array_ops.where`.
    y: third positional argument of `array_ops.where`.
    name: optional name, passed positionally as the fourth argument.

  Returns:
    The result of `array_ops.where(condition, x, y, name)`.
  """
  selected = array_ops.where(condition, x, y, name)
  return selected
# Direct aliases of the generated op wrappers.
select_and_scatter = gen_xla_ops.xla_select_and_scatter
send = gen_xla_ops.xla_send
def slice(x, start_dims, limit_dims, strides):
  """Slices `x` with one `start:limit:stride` range per leading dimension.

  Args:
    x: the value to index; must support indexing with a tuple of slices.
    start_dims: per-dimension start indices.
    limit_dims: per-dimension limit (exclusive end) indices.
    strides: per-dimension strides.

  Returns:
    `x[start_0:limit_0:stride_0, start_1:limit_1:stride_1, ...]`. The three
    sequences are zipped, so the shortest one determines how many dimensions
    are sliced.
  """
  # `_slice` is the builtin `slice`, captured before this function shadowed it.
  index = tuple(
      _slice(*bounds) for bounds in zip(start_dims, limit_dims, strides)
  )
  return x[index]
# Direct alias of the generated XlaSharding op wrapper.
sharding = gen_xla_ops.xla_sharding
@ops.RegisterGradient("XlaSharding")
def _sharding_grad(op, grad):
  """Gradient for XlaSharding op.

  Args:
    op: the forward XlaSharding operation; its `sharding` and
      `unspecified_dims` attrs are reused for the gradient.
    grad: the incoming gradient.

  Returns:
    A one-element list holding `grad` wrapped in an XlaSharding op that also
    carries a `_XlaSharding` attr with the same serialized sharding.
  """
  serialized_sharding = op.get_attr("sharding")
  unspecified_dims = op.get_attr("unspecified_dims")
  sharded_grad = gen_xla_ops.xla_sharding(
      grad,
      sharding=serialized_sharding,
      unspecified_dims=unspecified_dims,
  )
  sharding_attr_value = attr_value_pb2.AttrValue(s=serialized_sharding)
  # pylint: disable=protected-access
  sharded_grad.op._set_attr("_XlaSharding", sharding_attr_value)
  return [sharded_grad]
# Direct aliases of the generated SPMD shape-conversion op wrappers; each one
# is used as the gradient of the other (see the gradient functions below).
spmd_full_to_shard_shape = gen_xla_ops.xla_spmd_full_to_shard_shape
spmd_shard_to_full_shape = gen_xla_ops.xla_spmd_shard_to_full_shape
@ops.RegisterGradient("XlaSpmdFullToShardShape")
def _spmd_full_to_shard_shape_grad(op, grad):
  """Gradient for XlaSpmdFullToShardShape: converts `grad` back to full shape.

  Args:
    op: the forward operation; its `manual_sharding`, `dim` and
      `unspecified_dims` attrs and the static shape of its first input are
      reused.
    grad: the incoming gradient.

  Returns:
    A one-element list holding the XlaSpmdShardToFullShape result.
  """
  manual_sharding = op.get_attr("manual_sharding")
  # The full shape is the static shape of the forward op's input.
  full_shape = op.inputs[0].shape.as_list()
  dim = op.get_attr("dim")
  unspecified_dims = op.get_attr("unspecified_dims")
  full_grad = gen_xla_ops.xla_spmd_shard_to_full_shape(
      grad,
      manual_sharding=manual_sharding,
      full_shape=full_shape,
      dim=dim,
      unspecified_dims=unspecified_dims,
  )
  return [full_grad]
@ops.RegisterGradient("XlaSpmdShardToFullShape")
def _spmd_shard_to_full_shape_grad(op, grad):
  """Gradient for XlaSpmdShardToFullShape: converts `grad` to shard shape.

  Args:
    op: the forward operation; its `manual_sharding`, `dim` and
      `unspecified_dims` attrs are reused.
    grad: the incoming gradient.

  Returns:
    A one-element list holding the XlaSpmdFullToShardShape result.
  """
  manual_sharding = op.get_attr("manual_sharding")
  dim = op.get_attr("dim")
  unspecified_dims = op.get_attr("unspecified_dims")
  shard_grad = gen_xla_ops.xla_spmd_full_to_shard_shape(
      grad,
      manual_sharding=manual_sharding,
      dim=dim,
      unspecified_dims=unspecified_dims,
  )
  return [shard_grad]
# Direct aliases of the generated op wrappers; callers get the generated
# signatures unchanged.
sort = gen_xla_ops.xla_sort
key_value_sort = gen_xla_ops.xla_key_value_sort
variadic_sort = gen_xla_ops.xla_variadic_sort
while_loop = gen_xla_ops.xla_while
dequantize = gen_xla_ops.xla_dequantize
custom_call = gen_xla_ops.xla_custom_call
def custom_call_v2(
    call_target_name,
    operands,
    result_specs,
    backend_config=None,
    has_side_effect=None,
    name=None,
):
  """Emits an HLO `CustomCall` operation with multiple outputs.

  See `CustomCall` specification at
    https://tensorflow.org/xla/operation_semantics#customcall,
  and `mhlo.custom_call` specification at
    https://tensorflow.org/mlir/hlo_ops#mhlocustom_call_mlirmhlocustomcallop.

  Args:
    call_target_name: Name of the user function. The function signature must
      conform to version 3 of the API, see
      `API_VERSION_STATUS_RETURNING_UNIFIED`. All operands and results assumed
      to be in the default layout.
    operands: A sequence of tensors with possibly different types.
    result_specs: A sequence of tensor specs for all results. It is iterated
      twice (once for dtypes, once for shapes).
    backend_config: A string that encodes a metadata for the backend. Empty
      string by default.
    has_side_effect: Indicates whether the custom call has side effects. `False`
      by default.
    name: Optional name of the operation.

  Returns:
    A tuple of output tensors.
  """
  # None means "use the default" for both optional settings.
  if backend_config is None:
    backend_config = ""
  if has_side_effect is None:
    has_side_effect = False
  result_dtypes = tuple(spec.dtype for spec in result_specs)
  result_shapes = tuple(spec.shape for spec in result_specs)
  return gen_xla_ops.xla_custom_call_v2(
      operands=operands,
      call_target_name=call_target_name,
      backend_config=backend_config,
      has_side_effect=has_side_effect,
      result_dtypes=result_dtypes,
      result_shapes=result_shapes,
      name=name,
  )
# pylint: disable=g-doc-args
# pylint: disable=g-doc-return-or-yield
def call_module(
    args,
    *,
    version=4,
    module,
    Tout,
    Sout,
    platforms=(),
    function_list=(),
    has_token_input_output=False,
    disabled_checks=(),
    use_shardy_partitioner=False,
):
  """See documentation for the XlaCallModule op.

  https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_ops.cc+xlacallmodule&type=code

  All keyword arguments are forwarded unchanged to the op; `dim_args_spec` is
  always passed as an empty tuple. When the op call yields an `ops.Operation`
  rather than output tensors, an empty tuple is returned instead.
  """
  outputs = gen_xla_ops.xla_call_module(
      args,
      version=version,
      module=module,
      dim_args_spec=(),
      Tout=Tout,
      Sout=Sout,
      platforms=platforms,
      function_list=function_list,
      has_token_input_output=has_token_input_output,
      disabled_checks=disabled_checks,
      use_shardy_partitioner=use_shardy_partitioner,
  )
  # Since XLACallModule op is stateful, zero return function will return the TF
  # op under tf.function. It creates trouble for downstream codes.
  # Here we force it return empty tuple to work around it.
  # TODO(johnqiangzhang): Figure out a better way to handle control dependency.
  if isinstance(outputs, ops.Operation):
    return ()
  return outputs
def call_module_maximum_supported_version():
  """Maximum version of XlaCallModule op supported.

  See versioning details documentation for the XlaCallModule op at:
  https://github.com/search?q=repo%3Atensorflow%2Ftensorflow+path%3Axla_call_module+%22int+kVersionMaximumSupported%22&type=code

  Returns:
    The maximum supported version as an integer.
  """
  maximum_supported_version = 10
  return maximum_supported_version
# pylint: enable=g-doc-args
# pylint: enable=g-doc-return-or-yield
def call_module_disable_check_platform():
  """Returns the `disabled_checks` entry that turns off the platform check.

  For use with xla_call_module.disabled_checks.
  """
  platform_check_name = "platform"
  return platform_check_name
def gather(
    operand,
    start_indices,
    dimension_numbers,
    slice_sizes,
    indices_are_sorted=False,
    name=None,
):
  """Emits an XlaGather op.

  Args:
    operand: first positional input of the op.
    start_indices: second positional input of the op.
    dimension_numbers: a proto that is serialized with `SerializeToString()`
      before being passed to the op.
    slice_sizes: forwarded unchanged as `slice_sizes`.
    indices_are_sorted: forwarded unchanged; `False` by default.
    name: optional name of the operation.

  Returns:
    The output of `gen_xla_ops.xla_gather`.
  """
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_gather(
      operand,
      start_indices,
      slice_sizes=slice_sizes,
      dimension_numbers=serialized_dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      name=name,
  )
def scatter(
    operand,
    scatter_indices,
    updates,
    update_computation,
    dimension_numbers,
    indices_are_sorted=False,
    name=None,
):
  """Emits an XlaScatter op.

  Args:
    operand: first positional input of the op.
    scatter_indices: second positional input of the op.
    updates: third positional input of the op.
    update_computation: forwarded unchanged as `update_computation`.
    dimension_numbers: a proto that is serialized with `SerializeToString()`
      before being passed to the op.
    indices_are_sorted: forwarded unchanged; `False` by default.
    name: optional name of the operation.

  Returns:
    The output of `gen_xla_ops.xla_scatter`.
  """
  serialized_dimension_numbers = dimension_numbers.SerializeToString()
  return gen_xla_ops.xla_scatter(
      operand,
      scatter_indices,
      updates,
      update_computation=update_computation,
      dimension_numbers=serialized_dimension_numbers,
      indices_are_sorted=indices_are_sorted,
      name=name,
  )
def optimization_barrier(*args):
  """Passes all positional arguments, as one tuple, to XlaOptimizationBarrier.

  Args:
    *args: the values to pass through the barrier.

  Returns:
    The output of `gen_xla_ops.xla_optimization_barrier`.
  """
  operands = args
  return gen_xla_ops.xla_optimization_barrier(operands)
def reduce_precision(operand, exponent_bits, mantissa_bits):
  """Forwards its arguments, in order, to the XlaReducePrecision op wrapper.

  Args:
    operand: first positional argument of the op.
    exponent_bits: second positional argument of the op.
    mantissa_bits: third positional argument of the op.

  Returns:
    The output of `gen_xla_ops.xla_reduce_precision`.
  """
  reduced = gen_xla_ops.xla_reduce_precision(
      operand, exponent_bits, mantissa_bits
  )
  return reduced
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: tensorflow/compiler/tf2xla/tf2xla.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
28,
3,
'',
'tensorflow/compiler/tf2xla/tf2xla.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from tensorflow.core.framework import tensor_shape_pb2 as tensorflow_dot_core_dot_framework_dot_tensor__shape__pb2
from tensorflow.core.framework import types_pb2 as tensorflow_dot_core_dot_framework_dot_types__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\'tensorflow/compiler/tf2xla/tf2xla.proto\x12\x11tensorflow.tf2xla\x1a,tensorflow/core/framework/tensor_shape.proto\x1a%tensorflow/core/framework/types.proto\"3\n\x08TensorId\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x14\n\x0coutput_index\x18\x02 \x01(\x03\"\x8e\x01\n\x04\x46\x65\x65\x64\x12\'\n\x02id\x18\x01 \x01(\x0b\x32\x1b.tensorflow.tf2xla.TensorId\x12+\n\x05shape\x18\x02 \x01(\x0b\x32\x1c.tensorflow.TensorShapeProto\x12\x0c\n\x04name\x18\x03 \x01(\t\x12\"\n\x04type\x18\x04 \x01(\x0e\x32\x14.tensorflow.DataType\"\x8f\x01\n\x05\x46\x65tch\x12\'\n\x02id\x18\x01 \x01(\x0b\x32\x1b.tensorflow.tf2xla.TensorId\x12\x0c\n\x04name\x18\x02 \x01(\t\x12+\n\x05shape\x18\x03 \x01(\x0b\x32\x1c.tensorflow.TensorShapeProto\x12\"\n\x04type\x18\x04 \x01(\x0e\x32\x14.tensorflow.DataType\"\x8e\x01\n\x08Variable\x12\x11\n\tnode_name\x18\x01 \x01(\t\x12\x0c\n\x04name\x18\x02 \x01(\t\x12+\n\x05shape\x18\x03 \x01(\x0b\x32\x1c.tensorflow.TensorShapeProto\x12\"\n\x04type\x18\x04 \x01(\x0e\x32\x14.tensorflow.DataType\x12\x10\n\x08readonly\x18\x05 \x01(\x08\"\x87\x01\n\x06\x43onfig\x12%\n\x04\x66\x65\x65\x64\x18\x01 \x03(\x0b\x32\x17.tensorflow.tf2xla.Feed\x12\'\n\x05\x66\x65tch\x18\x02 \x03(\x0b\x32\x18.tensorflow.tf2xla.Fetch\x12-\n\x08variable\x18\x03 \x03(\x0b\x32\x1b.tensorflow.tf2xla.VariableB*\n\x15org.tensorflow.tf2xlaB\x0cTf2XlaProtosP\x01\xf8\x01\x01\x62\x06proto3')
# Generated code: build the descriptor and message objects for tf2xla.proto
# from DESCRIPTOR and publish them in this module's namespace.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.compiler.tf2xla.tf2xla_pb2', _globals)
# Only when the C descriptor implementation is not in use: set the file options
# and each message's start/end offsets (presumably byte offsets into the
# serialized file descriptor above).
if not _descriptor._USE_C_DESCRIPTORS:
  _globals['DESCRIPTOR']._loaded_options = None
  _globals['DESCRIPTOR']._serialized_options = b'\n\025org.tensorflow.tf2xlaB\014Tf2XlaProtosP\001\370\001\001'
  _globals['_TENSORID']._serialized_start=147
  _globals['_TENSORID']._serialized_end=198
  _globals['_FEED']._serialized_start=201
  _globals['_FEED']._serialized_end=343
  _globals['_FETCH']._serialized_start=346
  _globals['_FETCH']._serialized_end=489
  _globals['_VARIABLE']._serialized_start=492
  _globals['_VARIABLE']._serialized_end=634
  _globals['_CONFIG']._serialized_start=637
  _globals['_CONFIG']._serialized_end=772
# @@protoc_insertion_point(module_scope)
File diff suppressed because one or more lines are too long
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: xla/service/metrics.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
28,
3,
'',
'xla/service/metrics.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
from google.protobuf import duration_pb2 as google_dot_protobuf_dot_duration__pb2
from google.protobuf import timestamp_pb2 as google_dot_protobuf_dot_timestamp__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x19xla/service/metrics.proto\x12\x03xla\x1a\x19google/protobuf/any.proto\x1a\x1egoogle/protobuf/duration.proto\x1a\x1fgoogle/protobuf/timestamp.proto\",\n\x0eKeyValueMetric\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x03\"\xbc\x01\n\x0bPassMetrics\x12\x11\n\tmodule_id\x18\x01 \x01(\x04\x12\x11\n\tpass_name\x18\x02 \x01(\t\x12\x30\n\rpass_duration\x18\x03 \x01(\x0b\x32\x19.google.protobuf.Duration\x12,\n\x0e\x63ustom_metrics\x18\x04 \x01(\x0b\x32\x14.google.protobuf.Any\x12\'\n\nkv_metrics\x18\x05 \x03(\x0b\x32\x13.xla.KeyValueMetric\"\xbd\x01\n\x07JobInfo\x12\x11\n\x04name\x18\x01 \x01(\tH\x00\x88\x01\x01\x12\x11\n\x04\x63\x65ll\x18\x02 \x01(\tH\x01\x88\x01\x01\x12\x11\n\x04user\x18\x03 \x01(\tH\x02\x88\x01\x01\x12\x10\n\x03uid\x18\x04 \x01(\x03H\x03\x88\x01\x01\x12\x14\n\x07task_id\x18\x05 \x01(\x03H\x04\x88\x01\x01\x12\x15\n\x08task_uid\x18\x06 \x01(\x03H\x05\x88\x01\x01\x42\x07\n\x05_nameB\x07\n\x05_cellB\x07\n\x05_userB\x06\n\x04_uidB\n\n\x08_task_idB\x0b\n\t_task_uid\"\xa2\x03\n\x13\x43ompilationLogEntry\x12-\n\ttimestamp\x18\x01 \x01(\x0b\x32\x1a.google.protobuf.Timestamp\x12\x38\n\x05stage\x18\x02 \x01(\x0e\x32).xla.CompilationLogEntry.CompilationStage\x12+\n\x08\x64uration\x18\x03 \x01(\x0b\x32\x19.google.protobuf.Duration\x12\x12\n\ntask_index\x18\x04 \x01(\x05\x12&\n\x0cpass_metrics\x18\x05 \x03(\x0b\x32\x10.xla.PassMetrics\x12\x12\n\nmodule_ids\x18\x06 \x03(\x04\x12\x1e\n\x08job_info\x18\x07 \x01(\x0b\x32\x0c.xla.JobInfo\x12\x17\n\x0fhlo_module_name\x18\x08 \x01(\t\"l\n\x10\x43ompilationStage\x12\x0f\n\x0bUNSPECIFIED\x10\x00\x12\x0e\n\nEND_TO_END\x10\x01\x12\x0e\n\nHLO_PASSES\x10\x02\x12\x13\n\x0f\x43ODE_GENERATION\x10\x03\x12\x12\n\x0e\x42\x41\x43KEND_PASSES\x10\x04\x62\x06proto3')
# Generated code: build the descriptor and message objects for
# xla/service/metrics.proto from DESCRIPTOR and publish them in this module's
# namespace.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'xla.service.metrics_pb2', _globals)
# Only when the C descriptor implementation is not in use: clear the file
# options and set each message's/enum's start/end offsets (presumably byte
# offsets into the serialized file descriptor above).
if not _descriptor._USE_C_DESCRIPTORS:
  DESCRIPTOR._loaded_options = None
  _globals['_KEYVALUEMETRIC']._serialized_start=126
  _globals['_KEYVALUEMETRIC']._serialized_end=170
  _globals['_PASSMETRICS']._serialized_start=173
  _globals['_PASSMETRICS']._serialized_end=361
  _globals['_JOBINFO']._serialized_start=364
  _globals['_JOBINFO']._serialized_end=553
  _globals['_COMPILATIONLOGENTRY']._serialized_start=556
  _globals['_COMPILATIONLOGENTRY']._serialized_end=974
  _globals['_COMPILATIONLOGENTRY_COMPILATIONSTAGE']._serialized_start=866
  _globals['_COMPILATIONLOGENTRY_COMPILATIONSTAGE']._serialized_end=974
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: xla/tsl/protobuf/bfc_memory_map.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
28,
3,
'',
'xla/tsl/protobuf/bfc_memory_map.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n%xla/tsl/protobuf/bfc_memory_map.proto\x12\ntensorflow\"\x92\x01\n\x11MemAllocatorStats\x12\x12\n\nnum_allocs\x18\x01 \x01(\x03\x12\x14\n\x0c\x62ytes_in_use\x18\x02 \x01(\x03\x12\x19\n\x11peak_bytes_in_use\x18\x03 \x01(\x03\x12\x1a\n\x12largest_alloc_size\x18\x04 \x01(\x03\x12\x1c\n\x14\x66ragmentation_metric\x18\x05 \x01(\x02\"\xae\x01\n\x08MemChunk\x12\x0f\n\x07\x61\x64\x64ress\x18\x01 \x01(\x04\x12\x0c\n\x04size\x18\x02 \x01(\x03\x12\x16\n\x0erequested_size\x18\x03 \x01(\x03\x12\x0b\n\x03\x62in\x18\x04 \x01(\x05\x12\x0f\n\x07op_name\x18\x05 \x01(\t\x12\x16\n\x0e\x66reed_at_count\x18\x06 \x01(\x04\x12\x14\n\x0c\x61\x63tion_count\x18\x07 \x01(\x04\x12\x0e\n\x06in_use\x18\x08 \x01(\x08\x12\x0f\n\x07step_id\x18\t \x01(\x04\"\x8b\x01\n\nBinSummary\x12\x0b\n\x03\x62in\x18\x01 \x01(\x05\x12\x1a\n\x12total_bytes_in_use\x18\x02 \x01(\x03\x12\x1a\n\x12total_bytes_in_bin\x18\x03 \x01(\x03\x12\x1b\n\x13total_chunks_in_use\x18\x04 \x01(\x03\x12\x1b\n\x13total_chunks_in_bin\x18\x05 \x01(\x03\".\n\x08SnapShot\x12\x14\n\x0c\x61\x63tion_count\x18\x01 \x01(\x04\x12\x0c\n\x04size\x18\x02 \x01(\x03\"\xcd\x01\n\nMemoryDump\x12\x16\n\x0e\x61llocator_name\x18\x01 \x01(\t\x12+\n\x0b\x62in_summary\x18\x02 \x03(\x0b\x32\x16.tensorflow.BinSummary\x12#\n\x05\x63hunk\x18\x03 \x03(\x0b\x32\x14.tensorflow.MemChunk\x12\'\n\tsnap_shot\x18\x04 \x03(\x0b\x32\x14.tensorflow.SnapShot\x12,\n\x05stats\x18\x05 \x01(\x0b\x32\x1d.tensorflow.MemAllocatorStatsB@Z>github.com/google/tsl/tsl/go/protobuf/for_core_protos_go_protob\x06proto3')
# Generated code: build the descriptor and message objects for
# bfc_memory_map.proto from DESCRIPTOR and publish them in this module's
# namespace.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'xla.tsl.protobuf.bfc_memory_map_pb2', _globals)
# Only when the C descriptor implementation is not in use: set the file options
# and each message's start/end offsets (presumably byte offsets into the
# serialized file descriptor above).
if not _descriptor._USE_C_DESCRIPTORS:
  _globals['DESCRIPTOR']._loaded_options = None
  _globals['DESCRIPTOR']._serialized_options = b'Z>github.com/google/tsl/tsl/go/protobuf/for_core_protos_go_proto'
  _globals['_MEMALLOCATORSTATS']._serialized_start=54
  _globals['_MEMALLOCATORSTATS']._serialized_end=200
  _globals['_MEMCHUNK']._serialized_start=203
  _globals['_MEMCHUNK']._serialized_end=377
  _globals['_BINSUMMARY']._serialized_start=380
  _globals['_BINSUMMARY']._serialized_end=519
  _globals['_SNAPSHOT']._serialized_start=521
  _globals['_SNAPSHOT']._serialized_end=567
  _globals['_MEMORYDUMP']._serialized_start=570
  _globals['_MEMORYDUMP']._serialized_end=775
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,39 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: xla/tsl/protobuf/coordination_config.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
28,
3,
'',
'xla/tsl/protobuf/coordination_config.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n*xla/tsl/protobuf/coordination_config.proto\x12\ntensorflow\"1\n\x0e\x43oordinatedJob\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x11\n\tnum_tasks\x18\x02 \x01(\x05\"\xf7\x03\n\x19\x43oordinationServiceConfig\x12\x14\n\x0cservice_type\x18\x01 \x01(\t\x12\x16\n\x0eservice_leader\x18\x02 \x01(\t\x12\x1b\n\x13\x65nable_health_check\x18\x03 \x01(\x08\x12&\n\x1e\x63luster_register_timeout_in_ms\x18\x04 \x01(\x03\x12%\n\x1d\x63luster_register_with_barrier\x18\x0e \x01(\x08\x12\x1f\n\x17heartbeat_timeout_in_ms\x18\x05 \x01(\x03\x12\x38\n\x14\x63oordinated_job_list\x18\n \x03(\x0b\x32\x1a.tensorflow.CoordinatedJob\x12&\n\x1eshutdown_barrier_timeout_in_ms\x18\x07 \x01(\x03\x12*\n\"agent_destruction_without_shutdown\x18\x08 \x01(\x08\x12\x18\n\x10recoverable_jobs\x18\t \x03(\t\x12*\n\"allow_new_incarnation_to_reconnect\x18\x0b \x01(\x08\x12\x15\n\rforce_disable\x18\x0c \x01(\x08\x12.\n&poll_for_error_from_service_at_startup\x18\r \x01(\x08J\x04\x08\x06\x10\x07\x42WZUgithub.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_protob\x06proto3')
# Generated code: build the descriptor and message objects for
# coordination_config.proto from DESCRIPTOR and publish them in this module's
# namespace.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'xla.tsl.protobuf.coordination_config_pb2', _globals)
# Only when the C descriptor implementation is not in use: set the file options
# and each message's start/end offsets (presumably byte offsets into the
# serialized file descriptor above).
if not _descriptor._USE_C_DESCRIPTORS:
  _globals['DESCRIPTOR']._loaded_options = None
  _globals['DESCRIPTOR']._serialized_options = b'ZUgithub.com/tensorflow/tensorflow/tensorflow/go/core/protobuf/for_core_protos_go_proto'
  _globals['_COORDINATEDJOB']._serialized_start=58
  _globals['_COORDINATEDJOB']._serialized_end=107
  _globals['_COORDINATIONSERVICECONFIG']._serialized_start=110
  _globals['_COORDINATIONSERVICECONFIG']._serialized_end=613
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: xla/tsl/protobuf/distributed_runtime_payloads.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# Import-time guard: hand the gencode version (5.28.3, matching the
# "Protobuf Python Version" header above), an empty suffix and the source
# .proto path to the protobuf runtime's version validator.
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC, 5, 28, 3, '',
    'xla/tsl/protobuf/distributed_runtime_payloads.proto',
)
# @@protoc_insertion_point(imports)
# Handle on the process-wide default symbol database.
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n3xla/tsl/protobuf/distributed_runtime_payloads.proto\x12\x1etensorflow.distributed_runtime\"\x9d\x01\n\x14GrpcPayloadContainer\x12T\n\x08payloads\x18\x01 \x03(\x0b\x32\x42.tensorflow.distributed_runtime.GrpcPayloadContainer.PayloadsEntry\x1a/\n\rPayloadsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c:\x02\x38\x01\"\x12\n\x10GrpcPayloadsLost\"\x19\n\x17WorkerPossiblyRestartedBAZ<github.com/tsl/tsl/go/core/protobuf/for_core_protos_go_proto\xf8\x01\x01\x62\x06proto3')
# Build the descriptors and message classes for
# distributed_runtime_payloads.proto into this module's namespace; the
# `_globals['_...']` lookups below rely on the entries these calls add.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'xla.tsl.protobuf.distributed_runtime_payloads_pb2', _globals)
# Taken only when `_descriptor._USE_C_DESCRIPTORS` is falsy. Every assignment
# up to the insertion-point marker belongs to this branch, so the body must be
# indented (flush-left it is an IndentationError).
if not _descriptor._USE_C_DESCRIPTORS:
  # File-level options, kept as their serialized bytes.
  _globals['DESCRIPTOR']._loaded_options = None
  _globals['DESCRIPTOR']._serialized_options = b'Z<github.com/tsl/tsl/go/core/protobuf/for_core_protos_go_proto\370\001\001'
  # Options of the nested PayloadsEntry message (the `:\x02\x38\x01` bytes in
  # the serialized descriptor).
  _globals['_GRPCPAYLOADCONTAINER_PAYLOADSENTRY']._loaded_options = None
  _globals['_GRPCPAYLOADCONTAINER_PAYLOADSENTRY']._serialized_options = b'8\001'
  # [start, end) byte offsets of each message inside the serialized file
  # descriptor handed to AddSerializedFile; the nested entry lies inside its
  # parent's range.
  _globals['_GRPCPAYLOADCONTAINER']._serialized_start=88
  _globals['_GRPCPAYLOADCONTAINER']._serialized_end=245
  _globals['_GRPCPAYLOADCONTAINER_PAYLOADSENTRY']._serialized_start=198
  _globals['_GRPCPAYLOADCONTAINER_PAYLOADSENTRY']._serialized_end=245
  _globals['_GRPCPAYLOADSLOST']._serialized_start=247
  _globals['_GRPCPAYLOADSLOST']._serialized_end=265
  _globals['_WORKERPOSSIBLYRESTARTED']._serialized_start=267
  _globals['_WORKERPOSSIBLYRESTARTED']._serialized_end=292
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: xla/tsl/protobuf/error_codes.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# Import-time guard: hand the gencode version (5.28.3, matching the
# "Protobuf Python Version" header above), an empty suffix and the source
# .proto path to the protobuf runtime's version validator.
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC, 5, 28, 3, '',
    'xla/tsl/protobuf/error_codes.proto',
)
# @@protoc_insertion_point(imports)
# Handle on the process-wide default symbol database.
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\"xla/tsl/protobuf/error_codes.proto\x12\x10tensorflow.error*\x84\x03\n\x04\x43ode\x12\x06\n\x02OK\x10\x00\x12\r\n\tCANCELLED\x10\x01\x12\x0b\n\x07UNKNOWN\x10\x02\x12\x14\n\x10INVALID_ARGUMENT\x10\x03\x12\x15\n\x11\x44\x45\x41\x44LINE_EXCEEDED\x10\x04\x12\r\n\tNOT_FOUND\x10\x05\x12\x12\n\x0e\x41LREADY_EXISTS\x10\x06\x12\x15\n\x11PERMISSION_DENIED\x10\x07\x12\x13\n\x0fUNAUTHENTICATED\x10\x10\x12\x16\n\x12RESOURCE_EXHAUSTED\x10\x08\x12\x17\n\x13\x46\x41ILED_PRECONDITION\x10\t\x12\x0b\n\x07\x41\x42ORTED\x10\n\x12\x10\n\x0cOUT_OF_RANGE\x10\x0b\x12\x11\n\rUNIMPLEMENTED\x10\x0c\x12\x0c\n\x08INTERNAL\x10\r\x12\x0f\n\x0bUNAVAILABLE\x10\x0e\x12\r\n\tDATA_LOSS\x10\x0f\x12K\nGDO_NOT_USE_RESERVED_FOR_FUTURE_EXPANSION_USE_DEFAULT_IN_SWITCH_INSTEAD_\x10\x14\x42q\n\x18org.tensorflow.frameworkB\x10\x45rrorCodesProtosP\x01Z>github.com/google/tsl/tsl/go/protobuf/for_core_protos_go_proto\xf8\x01\x01\x62\x06proto3')
# Build the descriptors for error_codes.proto into this module's namespace;
# the `_globals['_CODE']` lookup below relies on the entries these calls add.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'xla.tsl.protobuf.error_codes_pb2', _globals)
# Taken only when `_descriptor._USE_C_DESCRIPTORS` is falsy. Every assignment
# up to the insertion-point marker belongs to this branch, so the body must be
# indented (flush-left it is an IndentationError).
if not _descriptor._USE_C_DESCRIPTORS:
  # File-level options, kept as their serialized bytes.
  _globals['DESCRIPTOR']._loaded_options = None
  _globals['DESCRIPTOR']._serialized_options = b'\n\030org.tensorflow.frameworkB\020ErrorCodesProtosP\001Z>github.com/google/tsl/tsl/go/protobuf/for_core_protos_go_proto\370\001\001'
  # [start, end) byte offsets of the `Code` enum inside the serialized file
  # descriptor handed to AddSerializedFile.
  _globals['_CODE']._serialized_start=57
  _globals['_CODE']._serialized_end=445
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,41 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: xla/tsl/protobuf/histogram.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# Import-time guard: hand the gencode version (5.28.3, matching the
# "Protobuf Python Version" header above), an empty suffix and the source
# .proto path to the protobuf runtime's version validator.
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC, 5, 28, 3, '',
    'xla/tsl/protobuf/histogram.proto',
)
# @@protoc_insertion_point(imports)
# Handle on the process-wide default symbol database.
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n xla/tsl/protobuf/histogram.proto\x12\ntensorflow\"\x87\x01\n\x0eHistogramProto\x12\x0b\n\x03min\x18\x01 \x01(\x01\x12\x0b\n\x03max\x18\x02 \x01(\x01\x12\x0b\n\x03num\x18\x03 \x01(\x01\x12\x0b\n\x03sum\x18\x04 \x01(\x01\x12\x13\n\x0bsum_squares\x18\x05 \x01(\x01\x12\x18\n\x0c\x62ucket_limit\x18\x06 \x03(\x01\x42\x02\x10\x01\x12\x12\n\x06\x62ucket\x18\x07 \x03(\x01\x42\x02\x10\x01\x42\\\n\x18org.tensorflow.frameworkP\x01Z;github.com/google/tsl/tsl/go/core/protobuf/summary_go_proto\xf8\x01\x01\x62\x06proto3')
# Build the descriptors and message classes for histogram.proto into this
# module's namespace; the `_globals['_HISTOGRAMPROTO']` lookups below rely on
# the entries these calls add.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'xla.tsl.protobuf.histogram_pb2', _globals)
# Taken only when `_descriptor._USE_C_DESCRIPTORS` is falsy. Every assignment
# up to the insertion-point marker belongs to this branch, so the body must be
# indented (flush-left it is an IndentationError).
if not _descriptor._USE_C_DESCRIPTORS:
  # File-level options, kept as their serialized bytes.
  _globals['DESCRIPTOR']._loaded_options = None
  _globals['DESCRIPTOR']._serialized_options = b'\n\030org.tensorflow.frameworkP\001Z;github.com/google/tsl/tsl/go/core/protobuf/summary_go_proto\370\001\001'
  # Field-level options for the two repeated fields (the `B\x02\x10\x01`
  # bytes after each field in the serialized descriptor).
  _globals['_HISTOGRAMPROTO'].fields_by_name['bucket_limit']._loaded_options = None
  _globals['_HISTOGRAMPROTO'].fields_by_name['bucket_limit']._serialized_options = b'\020\001'
  _globals['_HISTOGRAMPROTO'].fields_by_name['bucket']._loaded_options = None
  _globals['_HISTOGRAMPROTO'].fields_by_name['bucket']._serialized_options = b'\020\001'
  # [start, end) byte offsets of HistogramProto inside the serialized file
  # descriptor handed to AddSerializedFile.
  _globals['_HISTOGRAMPROTO']._serialized_start=49
  _globals['_HISTOGRAMPROTO']._serialized_end=184
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: xla/tsl/protobuf/rpc_options.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# Import-time guard: hand the gencode version (5.28.3, matching the
# "Protobuf Python Version" header above), an empty suffix and the source
# .proto path to the protobuf runtime's version validator.
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC, 5, 28, 3, '',
    'xla/tsl/protobuf/rpc_options.proto',
)
# @@protoc_insertion_point(imports)
# Handle on the process-wide default symbol database.
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\"xla/tsl/protobuf/rpc_options.proto\x12\ntensorflow\"\xd5\x01\n\nRPCOptions\x12$\n\x1cuse_rpc_for_inprocess_master\x18\x01 \x01(\x08\x12\x1d\n\x15\x63ompression_algorithm\x18\x02 \x01(\t\x12\x19\n\x11\x63ompression_level\x18\x03 \x01(\x05\x12\x1a\n\x12\x63\x61\x63he_rpc_response\x18\x04 \x01(\x08\x12*\n\"disable_session_connection_sharing\x18\x05 \x01(\x08\x12\x1f\n\x17num_channels_per_target\x18\x06 \x01(\x05\x42@Z>github.com/google/tsl/tsl/go/protobuf/for_core_protos_go_protob\x06proto3')
# Build the descriptors and message classes for rpc_options.proto into this
# module's namespace; the `_globals['_RPCOPTIONS']` lookups below rely on the
# entries these calls add.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'xla.tsl.protobuf.rpc_options_pb2', _globals)
# Taken only when `_descriptor._USE_C_DESCRIPTORS` is falsy. Every assignment
# up to the insertion-point marker belongs to this branch, so the body must be
# indented (flush-left it is an IndentationError).
if not _descriptor._USE_C_DESCRIPTORS:
  # File-level options, kept as their serialized bytes.
  _globals['DESCRIPTOR']._loaded_options = None
  _globals['DESCRIPTOR']._serialized_options = b'Z>github.com/google/tsl/tsl/go/protobuf/for_core_protos_go_proto'
  # [start, end) byte offsets of RPCOptions inside the serialized file
  # descriptor handed to AddSerializedFile.
  _globals['_RPCOPTIONS']._serialized_start=51
  _globals['_RPCOPTIONS']._serialized_end=264
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: xla/tsl/protobuf/status.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# Import-time guard: hand the gencode version (5.28.3, matching the
# "Protobuf Python Version" header above), an empty suffix and the source
# .proto path to the protobuf runtime's version validator.
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC, 5, 28, 3, '',
    'xla/tsl/protobuf/status.proto',
)
# @@protoc_insertion_point(imports)
# Handle on the process-wide default symbol database.
_sym_db = _symbol_database.Default()
from tensorflow.compiler.xla.tsl.protobuf import error_codes_pb2 as xla_dot_tsl_dot_protobuf_dot_error__codes__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1dxla/tsl/protobuf/status.proto\x12\ntensorflow\x1a\"xla/tsl/protobuf/error_codes.proto\"\xab\x01\n\x0bStatusProto\x12$\n\x04\x63ode\x18\x01 \x01(\x0e\x32\x16.tensorflow.error.Code\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x35\n\x07payload\x18\x03 \x03(\x0b\x32$.tensorflow.StatusProto.PayloadEntry\x1a.\n\x0cPayloadEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x0c:\x02\x38\x01\x42_\n\x18org.tensorflow.frameworkP\x01Z>github.com/google/tsl/tsl/go/protobuf/for_core_protos_go_proto\xf8\x01\x01\x62\x06proto3')
# Build the descriptors and message classes for status.proto into this
# module's namespace; the `_globals['_STATUSPROTO...']` lookups below rely on
# the entries these calls add.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'xla.tsl.protobuf.status_pb2', _globals)
# Taken only when `_descriptor._USE_C_DESCRIPTORS` is falsy. Every assignment
# up to the insertion-point marker belongs to this branch, so the body must be
# indented (flush-left it is an IndentationError).
if not _descriptor._USE_C_DESCRIPTORS:
  # File-level options, kept as their serialized bytes.
  _globals['DESCRIPTOR']._loaded_options = None
  _globals['DESCRIPTOR']._serialized_options = b'\n\030org.tensorflow.frameworkP\001Z>github.com/google/tsl/tsl/go/protobuf/for_core_protos_go_proto\370\001\001'
  # Options of the nested PayloadEntry message (the `:\x02\x38\x01` bytes in
  # the serialized descriptor).
  _globals['_STATUSPROTO_PAYLOADENTRY']._loaded_options = None
  _globals['_STATUSPROTO_PAYLOADENTRY']._serialized_options = b'8\001'
  # [start, end) byte offsets of each message inside the serialized file
  # descriptor handed to AddSerializedFile; the nested entry lies inside its
  # parent's range.
  _globals['_STATUSPROTO']._serialized_start=82
  _globals['_STATUSPROTO']._serialized_end=253
  _globals['_STATUSPROTO_PAYLOADENTRY']._serialized_start=207
  _globals['_STATUSPROTO_PAYLOADENTRY']._serialized_end=253
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,79 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: xla/tsl/protobuf/test_log.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# Import-time guard: hand the gencode version (5.28.3, matching the
# "Protobuf Python Version" header above), an empty suffix and the source
# .proto path to the protobuf runtime's version validator.
_runtime_version.ValidateProtobufRuntimeVersion(
    _runtime_version.Domain.PUBLIC, 5, 28, 3, '',
    'xla/tsl/protobuf/test_log.proto',
)
# @@protoc_insertion_point(imports)
# Handle on the process-wide default symbol database.
_sym_db = _symbol_database.Default()
from google.protobuf import any_pb2 as google_dot_protobuf_dot_any__pb2
from google.protobuf import wrappers_pb2 as google_dot_protobuf_dot_wrappers__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1fxla/tsl/protobuf/test_log.proto\x12\ntensorflow\x1a\x19google/protobuf/any.proto\x1a\x1egoogle/protobuf/wrappers.proto\"D\n\nEntryValue\x12\x16\n\x0c\x64ouble_value\x18\x01 \x01(\x01H\x00\x12\x16\n\x0cstring_value\x18\x02 \x01(\tH\x00\x42\x06\n\x04kind\"\x8c\x01\n\x0bMetricEntry\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x01\x12/\n\tmin_value\x18\x03 \x01(\x0b\x32\x1c.google.protobuf.DoubleValue\x12/\n\tmax_value\x18\x04 \x01(\x0b\x32\x1c.google.protobuf.DoubleValue\"\x8f\x02\n\x0e\x42\x65nchmarkEntry\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05iters\x18\x02 \x01(\x03\x12\x10\n\x08\x63pu_time\x18\x03 \x01(\x01\x12\x11\n\twall_time\x18\x04 \x01(\x01\x12\x12\n\nthroughput\x18\x05 \x01(\x01\x12\x36\n\x06\x65xtras\x18\x06 \x03(\x0b\x32&.tensorflow.BenchmarkEntry.ExtrasEntry\x12(\n\x07metrics\x18\x07 \x03(\x0b\x32\x17.tensorflow.MetricEntry\x1a\x45\n\x0b\x45xtrasEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12%\n\x05value\x18\x02 \x01(\x0b\x32\x16.tensorflow.EntryValue:\x02\x38\x01\"=\n\x10\x42\x65nchmarkEntries\x12)\n\x05\x65ntry\x18\x01 \x03(\x0b\x32\x1a.tensorflow.BenchmarkEntry\"B\n\x12\x42uildConfiguration\x12\x0c\n\x04mode\x18\x01 \x01(\t\x12\x10\n\x08\x63\x63_flags\x18\x02 \x03(\t\x12\x0c\n\x04opts\x18\x03 \x03(\t\"f\n\x08\x43ommitId\x12\x14\n\nchangelist\x18\x01 \x01(\x03H\x00\x12\x0e\n\x04hash\x18\x02 \x01(\tH\x00\x12\x10\n\x08snapshot\x18\x03 \x01(\t\x12\x1a\n\x12pending_changelist\x18\x04 \x01(\x03\x42\x06\n\x04kind\"\xde\x01\n\x07\x43PUInfo\x12\x11\n\tnum_cores\x18\x01 \x01(\x03\x12\x19\n\x11num_cores_allowed\x18\x02 \x01(\x03\x12\x13\n\x0bmhz_per_cpu\x18\x03 \x01(\x01\x12\x10\n\x08\x63pu_info\x18\x04 \x01(\t\x12\x14\n\x0c\x63pu_governor\x18\x05 \x01(\t\x12\x36\n\ncache_size\x18\x06 \x03(\x0b\x32\".tensorflow.CPUInfo.CacheSizeEntry\x1a\x30\n\x0e\x43\x61\x63heSizeEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 
\x01(\x03:\x02\x38\x01\".\n\nMemoryInfo\x12\r\n\x05total\x18\x01 \x01(\x03\x12\x11\n\tavailable\x18\x02 \x01(\x03\"6\n\x07GPUInfo\x12\r\n\x05model\x18\x01 \x01(\t\x12\x0c\n\x04uuid\x18\x02 \x01(\t\x12\x0e\n\x06\x62us_id\x18\x03 \x01(\t\"p\n\x0cPlatformInfo\x12\x0c\n\x04\x62its\x18\x01 \x01(\t\x12\x0f\n\x07linkage\x18\x02 \x01(\t\x12\x0f\n\x07machine\x18\x03 \x01(\t\x12\x0f\n\x07release\x18\x04 \x01(\t\x12\x0e\n\x06system\x18\x05 \x01(\t\x12\x0f\n\x07version\x18\x06 \x01(\t\"e\n\x13\x41vailableDeviceInfo\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04type\x18\x02 \x01(\t\x12\x14\n\x0cmemory_limit\x18\x03 \x01(\x03\x12\x1c\n\x14physical_description\x18\x04 \x01(\t\"\xb3\x02\n\x14MachineConfiguration\x12\x10\n\x08hostname\x18\x01 \x01(\t\x12\x19\n\x11serial_identifier\x18\x07 \x01(\t\x12/\n\rplatform_info\x18\x02 \x01(\x0b\x32\x18.tensorflow.PlatformInfo\x12%\n\x08\x63pu_info\x18\x03 \x01(\x0b\x32\x13.tensorflow.CPUInfo\x12)\n\x0b\x64\x65vice_info\x18\x04 \x03(\x0b\x32\x14.google.protobuf.Any\x12>\n\x15\x61vailable_device_info\x18\x05 \x03(\x0b\x32\x1f.tensorflow.AvailableDeviceInfo\x12+\n\x0bmemory_info\x18\x06 \x01(\x0b\x32\x16.tensorflow.MemoryInfo\"\x91\x01\n\x10RunConfiguration\x12\x10\n\x08\x61rgument\x18\x01 \x03(\t\x12;\n\x08\x65nv_vars\x18\x02 \x03(\x0b\x32).tensorflow.RunConfiguration.EnvVarsEntry\x1a.\n\x0c\x45nvVarsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xd0\x04\n\x0bTestResults\x12\x0e\n\x06target\x18\x01 \x01(\t\x12-\n\x07\x65ntries\x18\x02 \x01(\x0b\x32\x1c.tensorflow.BenchmarkEntries\x12;\n\x13\x62uild_configuration\x18\x03 \x01(\x0b\x32\x1e.tensorflow.BuildConfiguration\x12\'\n\tcommit_id\x18\x04 \x01(\x0b\x32\x14.tensorflow.CommitId\x12\x12\n\nstart_time\x18\x05 \x01(\x03\x12\x10\n\x08run_time\x18\x06 \x01(\x01\x12?\n\x15machine_configuration\x18\x07 \x01(\x0b\x32 .tensorflow.MachineConfiguration\x12\x37\n\x11run_configuration\x18\x08 \x01(\x0b\x32\x1c.tensorflow.RunConfiguration\x12\x0c\n\x04name\x18\t 
\x01(\t\x12=\n\x0e\x62\x65nchmark_type\x18\n \x01(\x0e\x32%.tensorflow.TestResults.BenchmarkType\x12\x10\n\x08run_mode\x18\x0b \x01(\t\x12\x12\n\ntf_version\x18\x0c \x01(\t\"\x88\x01\n\rBenchmarkType\x12\x0b\n\x07UNKNOWN\x10\x00\x12\x16\n\x12\x43PP_MICROBENCHMARK\x10\x01\x12\x14\n\x10PYTHON_BENCHMARK\x10\x02\x12\x15\n\x11\x41NDROID_BENCHMARK\x10\x03\x12\x12\n\x0e\x45\x44GE_BENCHMARK\x10\x04\x12\x11\n\rIOS_BENCHMARK\x10\x05\x42\x31\n\x1borg.tensorflow.util.testlogB\rTestLogProtosP\x01\xf8\x01\x01\x62\x06proto3')
# Build the descriptors and message classes for test_log.proto into this
# module's namespace; the `_globals['_...']` lookups below rely on the entries
# these calls add.
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'xla.tsl.protobuf.test_log_pb2', _globals)
# Taken only when `_descriptor._USE_C_DESCRIPTORS` is falsy. Every assignment
# up to the insertion-point marker belongs to this branch, so the body must be
# indented (flush-left it is an IndentationError).
if not _descriptor._USE_C_DESCRIPTORS:
  # File-level options, kept as their serialized bytes.
  _globals['DESCRIPTOR']._loaded_options = None
  _globals['DESCRIPTOR']._serialized_options = b'\n\033org.tensorflow.util.testlogB\rTestLogProtosP\001\370\001\001'
  # Options of the three nested *Entry messages (the `:\x02\x38\x01` bytes in
  # the serialized descriptor).
  _globals['_BENCHMARKENTRY_EXTRASENTRY']._loaded_options = None
  _globals['_BENCHMARKENTRY_EXTRASENTRY']._serialized_options = b'8\001'
  _globals['_CPUINFO_CACHESIZEENTRY']._loaded_options = None
  _globals['_CPUINFO_CACHESIZEENTRY']._serialized_options = b'8\001'
  _globals['_RUNCONFIGURATION_ENVVARSENTRY']._loaded_options = None
  _globals['_RUNCONFIGURATION_ENVVARSENTRY']._serialized_options = b'8\001'
  # [start, end) byte offsets of each message/enum inside the serialized file
  # descriptor handed to AddSerializedFile; nested types lie inside their
  # parent's range.
  _globals['_ENTRYVALUE']._serialized_start=106
  _globals['_ENTRYVALUE']._serialized_end=174
  _globals['_METRICENTRY']._serialized_start=177
  _globals['_METRICENTRY']._serialized_end=317
  _globals['_BENCHMARKENTRY']._serialized_start=320
  _globals['_BENCHMARKENTRY']._serialized_end=591
  _globals['_BENCHMARKENTRY_EXTRASENTRY']._serialized_start=522
  _globals['_BENCHMARKENTRY_EXTRASENTRY']._serialized_end=591
  _globals['_BENCHMARKENTRIES']._serialized_start=593
  _globals['_BENCHMARKENTRIES']._serialized_end=654
  _globals['_BUILDCONFIGURATION']._serialized_start=656
  _globals['_BUILDCONFIGURATION']._serialized_end=722
  _globals['_COMMITID']._serialized_start=724
  _globals['_COMMITID']._serialized_end=826
  _globals['_CPUINFO']._serialized_start=829
  _globals['_CPUINFO']._serialized_end=1051
  _globals['_CPUINFO_CACHESIZEENTRY']._serialized_start=1003
  _globals['_CPUINFO_CACHESIZEENTRY']._serialized_end=1051
  _globals['_MEMORYINFO']._serialized_start=1053
  _globals['_MEMORYINFO']._serialized_end=1099
  _globals['_GPUINFO']._serialized_start=1101
  _globals['_GPUINFO']._serialized_end=1155
  _globals['_PLATFORMINFO']._serialized_start=1157
  _globals['_PLATFORMINFO']._serialized_end=1269
  _globals['_AVAILABLEDEVICEINFO']._serialized_start=1271
  _globals['_AVAILABLEDEVICEINFO']._serialized_end=1372
  _globals['_MACHINECONFIGURATION']._serialized_start=1375
  _globals['_MACHINECONFIGURATION']._serialized_end=1682
  _globals['_RUNCONFIGURATION']._serialized_start=1685
  _globals['_RUNCONFIGURATION']._serialized_end=1830
  _globals['_RUNCONFIGURATION_ENVVARSENTRY']._serialized_start=1784
  _globals['_RUNCONFIGURATION_ENVVARSENTRY']._serialized_end=1830
  _globals['_TESTRESULTS']._serialized_start=1833
  _globals['_TESTRESULTS']._serialized_end=2425
  _globals['_TESTRESULTS_BENCHMARKTYPE']._serialized_start=2289
  _globals['_TESTRESULTS_BENCHMARKTYPE']._serialized_end=2425
# @@protoc_insertion_point(module_scope)
File diff suppressed because one or more lines are too long