Import tensorflow

This commit is contained in:
2026-02-15 21:45:42 -08:00
parent f3e8b90764
commit c530630153
20524 changed files with 9017694 additions and 25 deletions
@@ -0,0 +1,35 @@
from sys import modules
from types import ModuleType
def __update_globals(new_import_path, pywrap_m):
all_names = pywrap_m.__all__ if hasattr(pywrap_m, '__all__') else dir(
pywrap_m)
modules[new_import_path] = pywrap_m
for name in all_names:
sub_pywrap = getattr(pywrap_m, name)
if isinstance(sub_pywrap, ModuleType):
sub_name = sub_pywrap.__name__[len(pywrap_m.__name__):]
__update_globals(new_import_path + sub_name, sub_pywrap)
def __try_import():
imports_paths = ["litert.python.pywrap_genai_ops", "third_party.tensorflow.lite.python.pywrap_genai_ops", "tensorflow.pywrap_genai_ops", "tensorflow.python.pywrap_genai_ops"] # template_val
exceptions = []
last_exception = None
for import_path in imports_paths:
try:
pywrap_m = __import__(import_path, fromlist=["*"])
__update_globals(__name__, pywrap_m)
return
except ImportError as e:
exceptions.append(str(e))
last_exception = e
pass
raise RuntimeError(f"""
Could not import original test/binary location, import paths tried: {imports_paths}.
Previous exceptions: {exceptions}""", last_exception)
__try_import()
@@ -0,0 +1,423 @@
"""Python wrappers around TensorFlow ops.
This file is MACHINE GENERATED! Do not edit.
"""
import collections
from tensorflow.python import pywrap_tfe as pywrap_tfe
from tensorflow.python.eager import context as _context
from tensorflow.python.eager import core as _core
from tensorflow.python.eager import execute as _execute
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.security.fuzzing.py import annotation_types as _atypes
from tensorflow.python.framework import op_def_registry as _op_def_registry
from tensorflow.python.framework import ops as _ops
from tensorflow.python.framework import op_def_library as _op_def_library
from tensorflow.python.util.deprecation import deprecated_endpoints
from tensorflow.python.util import dispatch as _dispatch
from tensorflow.python.util.tf_export import tf_export
from typing import TypeVar, List, Any
from typing_extensions import Annotated
TV_AudioMicrofrontend_out_type = TypeVar("TV_AudioMicrofrontend_out_type", "_atypes.Float32", "_atypes.UInt16")
@_dispatch.add_fallback_dispatch_list
@_dispatch.add_type_based_api_dispatcher
@tf_export('audio_microfrontend')
def audio_microfrontend(audio: Annotated[Any, _atypes.Int16], sample_rate:int=16000, window_size:int=25, window_step:int=10, num_channels:int=32, upper_band_limit:float=7500, lower_band_limit:float=125, smoothing_bits:int=10, even_smoothing:float=0.025, odd_smoothing:float=0.06, min_signal_remaining:float=0.05, enable_pcan:bool=False, pcan_strength:float=0.95, pcan_offset:float=80, gain_bits:int=21, enable_log:bool=True, scale_shift:int=6, left_context:int=0, right_context:int=0, frame_stride:int=1, zero_padding:bool=False, out_scale:int=1, out_type:TV_AudioMicrofrontend_out_type=_dtypes.uint16, name=None) -> Annotated[Any, TV_AudioMicrofrontend_out_type]:
r"""Audio Microfrontend Op.
This Op converts a sequence of audio data into one or more
feature vectors containing filterbanks of the input. The
conversion process uses a lightweight library to perform:
1. A slicing window function
2. Short-time FFTs
3. Filterbank calculations
4. Noise reduction
5. PCAN Auto Gain Control
6. Logarithmic scaling
Arguments
audio: 1D Tensor, int16 audio data in temporal ordering.
sample_rate: Integer, the sample rate of the audio in Hz.
window_size: Integer, length of desired time frames in ms.
window_step: Integer, length of step size for the next frame in ms.
num_channels: Integer, the number of filterbank channels to use.
upper_band_limit: Float, the highest frequency included in the filterbanks.
lower_band_limit: Float, the lowest frequency included in the filterbanks.
smoothing_bits: Int, scale up signal by 2^(smoothing_bits) before reduction.
even_smoothing: Float, smoothing coefficient for even-numbered channels.
odd_smoothing: Float, smoothing coefficient for odd-numbered channels.
min_signal_remaining: Float, fraction of signal to preserve in smoothing.
enable_pcan: Bool, enable PCAN auto gain control.
pcan_strength: Float, gain normalization exponent.
pcan_offset: Float, positive value added in the normalization denominator.
gain_bits: Int, number of fractional bits in the gain.
enable_log: Bool, enable logarithmic scaling of filterbanks.
scale_shift: Integer, scale filterbanks by 2^(scale_shift).
left_context: Integer, number of preceding frames to attach to each frame.
right_context: Integer, number of preceding frames to attach to each frame.
frame_stride: Integer, M frames to skip over, where output[n] = frame[n*M].
zero_padding: Bool, if left/right context is out-of-bounds, attach frame of
zeroes. Otherwise, frame[0] or frame[size-1] will be copied.
out_scale: Integer, divide all filterbanks by this number.
out_type: DType, type of the output Tensor, defaults to UINT16.
Returns
filterbanks: 2D Tensor, each row is a time frame, each column is a channel.
Args:
audio: A `Tensor` of type `int16`.
sample_rate: An optional `int`. Defaults to `16000`.
window_size: An optional `int`. Defaults to `25`.
window_step: An optional `int`. Defaults to `10`.
num_channels: An optional `int`. Defaults to `32`.
upper_band_limit: An optional `float`. Defaults to `7500`.
lower_band_limit: An optional `float`. Defaults to `125`.
smoothing_bits: An optional `int`. Defaults to `10`.
even_smoothing: An optional `float`. Defaults to `0.025`.
odd_smoothing: An optional `float`. Defaults to `0.06`.
min_signal_remaining: An optional `float`. Defaults to `0.05`.
enable_pcan: An optional `bool`. Defaults to `False`.
pcan_strength: An optional `float`. Defaults to `0.95`.
pcan_offset: An optional `float`. Defaults to `80`.
gain_bits: An optional `int`. Defaults to `21`.
enable_log: An optional `bool`. Defaults to `True`.
scale_shift: An optional `int`. Defaults to `6`.
left_context: An optional `int`. Defaults to `0`.
right_context: An optional `int`. Defaults to `0`.
frame_stride: An optional `int`. Defaults to `1`.
zero_padding: An optional `bool`. Defaults to `False`.
out_scale: An optional `int`. Defaults to `1`.
out_type: An optional `tf.DType` from: `tf.uint16, tf.float32`. Defaults to `tf.uint16`.
name: A name for the operation (optional).
Returns:
A `Tensor` of type `out_type`.
"""
_ctx = _context._context or _context.context()
tld = _ctx._thread_local_data
if tld.is_eager:
try:
_result = pywrap_tfe.TFE_Py_FastPathExecute(
_ctx, "AudioMicrofrontend", name, audio, "sample_rate", sample_rate,
"window_size", window_size, "window_step", window_step,
"num_channels", num_channels, "upper_band_limit", upper_band_limit,
"lower_band_limit", lower_band_limit, "smoothing_bits",
smoothing_bits, "even_smoothing", even_smoothing, "odd_smoothing",
odd_smoothing, "min_signal_remaining", min_signal_remaining,
"enable_pcan", enable_pcan, "pcan_strength", pcan_strength,
"pcan_offset", pcan_offset, "gain_bits", gain_bits, "enable_log",
enable_log, "scale_shift", scale_shift, "left_context", left_context,
"right_context", right_context, "frame_stride", frame_stride,
"zero_padding", zero_padding, "out_scale", out_scale, "out_type",
out_type)
return _result
except _core._NotOkStatusException as e:
_ops.raise_from_not_ok_status(e, name)
except _core._FallbackException:
pass
try:
_result = _dispatcher_for_audio_microfrontend(
(audio, sample_rate, window_size, window_step, num_channels,
upper_band_limit, lower_band_limit, smoothing_bits, even_smoothing,
odd_smoothing, min_signal_remaining, enable_pcan, pcan_strength,
pcan_offset, gain_bits, enable_log, scale_shift, left_context,
right_context, frame_stride, zero_padding, out_scale, out_type,
name,), None)
if _result is not NotImplemented:
return _result
return audio_microfrontend_eager_fallback(
audio, sample_rate=sample_rate, window_size=window_size,
window_step=window_step, num_channels=num_channels,
upper_band_limit=upper_band_limit,
lower_band_limit=lower_band_limit, smoothing_bits=smoothing_bits,
even_smoothing=even_smoothing, odd_smoothing=odd_smoothing,
min_signal_remaining=min_signal_remaining, enable_pcan=enable_pcan,
pcan_strength=pcan_strength, pcan_offset=pcan_offset,
gain_bits=gain_bits, enable_log=enable_log, scale_shift=scale_shift,
left_context=left_context, right_context=right_context,
frame_stride=frame_stride, zero_padding=zero_padding,
out_scale=out_scale, out_type=out_type, name=name, ctx=_ctx)
except _core._SymbolicException:
pass # Add nodes to the TensorFlow graph.
except (TypeError, ValueError):
_result = _dispatch.dispatch(
audio_microfrontend, (), dict(audio=audio,
sample_rate=sample_rate,
window_size=window_size,
window_step=window_step,
num_channels=num_channels,
upper_band_limit=upper_band_limit,
lower_band_limit=lower_band_limit,
smoothing_bits=smoothing_bits,
even_smoothing=even_smoothing,
odd_smoothing=odd_smoothing,
min_signal_remaining=min_signal_remaining,
enable_pcan=enable_pcan,
pcan_strength=pcan_strength,
pcan_offset=pcan_offset,
gain_bits=gain_bits,
enable_log=enable_log,
scale_shift=scale_shift,
left_context=left_context,
right_context=right_context,
frame_stride=frame_stride,
zero_padding=zero_padding,
out_scale=out_scale,
out_type=out_type, name=name)
)
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
return _result
raise
else:
_result = _dispatcher_for_audio_microfrontend(
(audio, sample_rate, window_size, window_step, num_channels,
upper_band_limit, lower_band_limit, smoothing_bits, even_smoothing,
odd_smoothing, min_signal_remaining, enable_pcan, pcan_strength,
pcan_offset, gain_bits, enable_log, scale_shift, left_context,
right_context, frame_stride, zero_padding, out_scale, out_type,
name,), None)
if _result is not NotImplemented:
return _result
# Add nodes to the TensorFlow graph.
if sample_rate is None:
sample_rate = 16000
sample_rate = _execute.make_int(sample_rate, "sample_rate")
if window_size is None:
window_size = 25
window_size = _execute.make_int(window_size, "window_size")
if window_step is None:
window_step = 10
window_step = _execute.make_int(window_step, "window_step")
if num_channels is None:
num_channels = 32
num_channels = _execute.make_int(num_channels, "num_channels")
if upper_band_limit is None:
upper_band_limit = 7500
upper_band_limit = _execute.make_float(upper_band_limit, "upper_band_limit")
if lower_band_limit is None:
lower_band_limit = 125
lower_band_limit = _execute.make_float(lower_band_limit, "lower_band_limit")
if smoothing_bits is None:
smoothing_bits = 10
smoothing_bits = _execute.make_int(smoothing_bits, "smoothing_bits")
if even_smoothing is None:
even_smoothing = 0.025
even_smoothing = _execute.make_float(even_smoothing, "even_smoothing")
if odd_smoothing is None:
odd_smoothing = 0.06
odd_smoothing = _execute.make_float(odd_smoothing, "odd_smoothing")
if min_signal_remaining is None:
min_signal_remaining = 0.05
min_signal_remaining = _execute.make_float(min_signal_remaining, "min_signal_remaining")
if enable_pcan is None:
enable_pcan = False
enable_pcan = _execute.make_bool(enable_pcan, "enable_pcan")
if pcan_strength is None:
pcan_strength = 0.95
pcan_strength = _execute.make_float(pcan_strength, "pcan_strength")
if pcan_offset is None:
pcan_offset = 80
pcan_offset = _execute.make_float(pcan_offset, "pcan_offset")
if gain_bits is None:
gain_bits = 21
gain_bits = _execute.make_int(gain_bits, "gain_bits")
if enable_log is None:
enable_log = True
enable_log = _execute.make_bool(enable_log, "enable_log")
if scale_shift is None:
scale_shift = 6
scale_shift = _execute.make_int(scale_shift, "scale_shift")
if left_context is None:
left_context = 0
left_context = _execute.make_int(left_context, "left_context")
if right_context is None:
right_context = 0
right_context = _execute.make_int(right_context, "right_context")
if frame_stride is None:
frame_stride = 1
frame_stride = _execute.make_int(frame_stride, "frame_stride")
if zero_padding is None:
zero_padding = False
zero_padding = _execute.make_bool(zero_padding, "zero_padding")
if out_scale is None:
out_scale = 1
out_scale = _execute.make_int(out_scale, "out_scale")
if out_type is None:
out_type = _dtypes.uint16
out_type = _execute.make_type(out_type, "out_type")
try:
_, _, _op, _outputs = _op_def_library._apply_op_helper(
"AudioMicrofrontend", audio=audio, sample_rate=sample_rate,
window_size=window_size,
window_step=window_step,
num_channels=num_channels,
upper_band_limit=upper_band_limit,
lower_band_limit=lower_band_limit,
smoothing_bits=smoothing_bits,
even_smoothing=even_smoothing,
odd_smoothing=odd_smoothing,
min_signal_remaining=min_signal_remaining,
enable_pcan=enable_pcan,
pcan_strength=pcan_strength,
pcan_offset=pcan_offset, gain_bits=gain_bits,
enable_log=enable_log, scale_shift=scale_shift,
left_context=left_context,
right_context=right_context,
frame_stride=frame_stride,
zero_padding=zero_padding, out_scale=out_scale,
out_type=out_type, name=name)
except (TypeError, ValueError):
_result = _dispatch.dispatch(
audio_microfrontend, (), dict(audio=audio, sample_rate=sample_rate,
window_size=window_size,
window_step=window_step,
num_channels=num_channels,
upper_band_limit=upper_band_limit,
lower_band_limit=lower_band_limit,
smoothing_bits=smoothing_bits,
even_smoothing=even_smoothing,
odd_smoothing=odd_smoothing,
min_signal_remaining=min_signal_remaining,
enable_pcan=enable_pcan,
pcan_strength=pcan_strength,
pcan_offset=pcan_offset,
gain_bits=gain_bits,
enable_log=enable_log,
scale_shift=scale_shift,
left_context=left_context,
right_context=right_context,
frame_stride=frame_stride,
zero_padding=zero_padding,
out_scale=out_scale,
out_type=out_type, name=name)
)
if _result is not _dispatch.OpDispatcher.NOT_SUPPORTED:
return _result
raise
_result = _outputs[:]
if _execute.must_record_gradient():
_attrs = ("sample_rate", _op._get_attr_int("sample_rate"), "window_size",
_op._get_attr_int("window_size"), "window_step",
_op._get_attr_int("window_step"), "num_channels",
_op._get_attr_int("num_channels"), "upper_band_limit",
_op.get_attr("upper_band_limit"), "lower_band_limit",
_op.get_attr("lower_band_limit"), "smoothing_bits",
_op._get_attr_int("smoothing_bits"), "even_smoothing",
_op.get_attr("even_smoothing"), "odd_smoothing",
_op.get_attr("odd_smoothing"), "min_signal_remaining",
_op.get_attr("min_signal_remaining"), "enable_pcan",
_op._get_attr_bool("enable_pcan"), "pcan_strength",
_op.get_attr("pcan_strength"), "pcan_offset",
_op.get_attr("pcan_offset"), "gain_bits",
_op._get_attr_int("gain_bits"), "enable_log",
_op._get_attr_bool("enable_log"), "scale_shift",
_op._get_attr_int("scale_shift"), "left_context",
_op._get_attr_int("left_context"), "right_context",
_op._get_attr_int("right_context"), "frame_stride",
_op._get_attr_int("frame_stride"), "zero_padding",
_op._get_attr_bool("zero_padding"), "out_scale",
_op._get_attr_int("out_scale"), "out_type",
_op._get_attr_type("out_type"))
_inputs_flat = _op.inputs
_execute.record_gradient(
"AudioMicrofrontend", _inputs_flat, _attrs, _result)
_result, = _result
return _result
AudioMicrofrontend = tf_export("raw_ops.AudioMicrofrontend")(_ops.to_raw_op(audio_microfrontend))
_dispatcher_for_audio_microfrontend = audio_microfrontend._tf_type_based_dispatcher.Dispatch
def audio_microfrontend_eager_fallback(audio: Annotated[Any, _atypes.Int16], sample_rate: int, window_size: int, window_step: int, num_channels: int, upper_band_limit: float, lower_band_limit: float, smoothing_bits: int, even_smoothing: float, odd_smoothing: float, min_signal_remaining: float, enable_pcan: bool, pcan_strength: float, pcan_offset: float, gain_bits: int, enable_log: bool, scale_shift: int, left_context: int, right_context: int, frame_stride: int, zero_padding: bool, out_scale: int, out_type: TV_AudioMicrofrontend_out_type, name, ctx) -> Annotated[Any, TV_AudioMicrofrontend_out_type]:
if sample_rate is None:
sample_rate = 16000
sample_rate = _execute.make_int(sample_rate, "sample_rate")
if window_size is None:
window_size = 25
window_size = _execute.make_int(window_size, "window_size")
if window_step is None:
window_step = 10
window_step = _execute.make_int(window_step, "window_step")
if num_channels is None:
num_channels = 32
num_channels = _execute.make_int(num_channels, "num_channels")
if upper_band_limit is None:
upper_band_limit = 7500
upper_band_limit = _execute.make_float(upper_band_limit, "upper_band_limit")
if lower_band_limit is None:
lower_band_limit = 125
lower_band_limit = _execute.make_float(lower_band_limit, "lower_band_limit")
if smoothing_bits is None:
smoothing_bits = 10
smoothing_bits = _execute.make_int(smoothing_bits, "smoothing_bits")
if even_smoothing is None:
even_smoothing = 0.025
even_smoothing = _execute.make_float(even_smoothing, "even_smoothing")
if odd_smoothing is None:
odd_smoothing = 0.06
odd_smoothing = _execute.make_float(odd_smoothing, "odd_smoothing")
if min_signal_remaining is None:
min_signal_remaining = 0.05
min_signal_remaining = _execute.make_float(min_signal_remaining, "min_signal_remaining")
if enable_pcan is None:
enable_pcan = False
enable_pcan = _execute.make_bool(enable_pcan, "enable_pcan")
if pcan_strength is None:
pcan_strength = 0.95
pcan_strength = _execute.make_float(pcan_strength, "pcan_strength")
if pcan_offset is None:
pcan_offset = 80
pcan_offset = _execute.make_float(pcan_offset, "pcan_offset")
if gain_bits is None:
gain_bits = 21
gain_bits = _execute.make_int(gain_bits, "gain_bits")
if enable_log is None:
enable_log = True
enable_log = _execute.make_bool(enable_log, "enable_log")
if scale_shift is None:
scale_shift = 6
scale_shift = _execute.make_int(scale_shift, "scale_shift")
if left_context is None:
left_context = 0
left_context = _execute.make_int(left_context, "left_context")
if right_context is None:
right_context = 0
right_context = _execute.make_int(right_context, "right_context")
if frame_stride is None:
frame_stride = 1
frame_stride = _execute.make_int(frame_stride, "frame_stride")
if zero_padding is None:
zero_padding = False
zero_padding = _execute.make_bool(zero_padding, "zero_padding")
if out_scale is None:
out_scale = 1
out_scale = _execute.make_int(out_scale, "out_scale")
if out_type is None:
out_type = _dtypes.uint16
out_type = _execute.make_type(out_type, "out_type")
audio = _ops.convert_to_tensor(audio, _dtypes.int16)
_inputs_flat = [audio]
_attrs = ("sample_rate", sample_rate, "window_size", window_size,
"window_step", window_step, "num_channels", num_channels,
"upper_band_limit", upper_band_limit, "lower_band_limit", lower_band_limit,
"smoothing_bits", smoothing_bits, "even_smoothing", even_smoothing,
"odd_smoothing", odd_smoothing, "min_signal_remaining",
min_signal_remaining, "enable_pcan", enable_pcan, "pcan_strength",
pcan_strength, "pcan_offset", pcan_offset, "gain_bits", gain_bits,
"enable_log", enable_log, "scale_shift", scale_shift, "left_context",
left_context, "right_context", right_context, "frame_stride", frame_stride,
"zero_padding", zero_padding, "out_scale", out_scale, "out_type", out_type)
_result = _execute.execute(b"AudioMicrofrontend", 1, inputs=_inputs_flat,
attrs=_attrs, ctx=ctx, name=name)
if _execute.must_record_gradient():
_execute.record_gradient(
"AudioMicrofrontend", _inputs_flat, _attrs, _result)
_result, = _result
return _result
@@ -0,0 +1,110 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""AudioMicrofrontend Op creates filterbanks from audio data."""
from tensorflow.lite.experimental.microfrontend.ops import gen_audio_microfrontend_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import load_library
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import resource_loader
_audio_microfrontend_op = load_library.load_op_library(
resource_loader.get_path_to_datafile("_audio_microfrontend_op.so"))
def audio_microfrontend(audio,
sample_rate=16000,
window_size=25,
window_step=10,
num_channels=32,
upper_band_limit=7500.0,
lower_band_limit=125.0,
smoothing_bits=10,
even_smoothing=0.025,
odd_smoothing=0.06,
min_signal_remaining=0.05,
enable_pcan=True,
pcan_strength=0.95,
pcan_offset=80.0,
gain_bits=21,
enable_log=True,
scale_shift=6,
left_context=0,
right_context=0,
frame_stride=1,
zero_padding=False,
out_scale=1,
out_type=dtypes.uint16):
"""Audio Microfrontend Op.
This Op converts a sequence of audio data into one or more
feature vectors containing filterbanks of the input. The
conversion process uses a lightweight library to perform:
1. A slicing window function
2. Short-time FFTs
3. Filterbank calculations
4. Noise reduction
5. PCAN Auto Gain Control
6. Logarithmic scaling
Args:
audio: 1D Tensor, int16 audio data in temporal ordering.
sample_rate: Integer, the sample rate of the audio in Hz.
window_size: Integer, length of desired time frames in ms.
window_step: Integer, length of step size for the next frame in ms.
num_channels: Integer, the number of filterbank channels to use.
upper_band_limit: Float, the highest frequency included in the filterbanks.
lower_band_limit: Float, the lowest frequency included in the filterbanks.
smoothing_bits: Int, scale up signal by 2^(smoothing_bits) before reduction.
even_smoothing: Float, smoothing coefficient for even-numbered channels.
odd_smoothing: Float, smoothing coefficient for odd-numbered channels.
min_signal_remaining: Float, fraction of signal to preserve in smoothing.
enable_pcan: Bool, enable PCAN auto gain control.
pcan_strength: Float, gain normalization exponent.
pcan_offset: Float, positive value added in the normalization denominator.
gain_bits: Int, number of fractional bits in the gain.
enable_log: Bool, enable logarithmic scaling of filterbanks.
scale_shift: Integer, scale filterbanks by 2^(scale_shift).
left_context: Integer, number of preceding frames to attach to each frame.
right_context: Integer, number of preceding frames to attach to each frame.
frame_stride: Integer, M frames to skip over, where output[n] = frame[n*M].
zero_padding: Bool, if left/right context is out-of-bounds, attach frame of
zeroes. Otherwise, frame[0] or frame[size-1] will be copied.
out_scale: Integer, divide all filterbanks by this number.
out_type: DType, type of the output Tensor, defaults to UINT16.
Returns:
filterbanks: 2D Tensor, each row is a time frame, each column is a channel.
Raises:
ValueError: If the audio tensor is not explicitly a vector.
"""
audio_shape = audio.shape
if audio_shape.ndims is None:
raise ValueError("Input to `AudioMicrofrontend` should have known rank.")
if len(audio_shape) > 1:
audio = array_ops.reshape(audio, [-1])
return gen_audio_microfrontend_op.audio_microfrontend(
audio, sample_rate, window_size, window_step, num_channels,
upper_band_limit, lower_band_limit, smoothing_bits, even_smoothing,
odd_smoothing, min_signal_remaining, enable_pcan, pcan_strength,
pcan_offset, gain_bits, enable_log, scale_shift, left_context,
right_context, frame_stride, zero_padding, out_scale, out_type)
ops.NotDifferentiable("AudioMicrofrontend")
@@ -0,0 +1,66 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: tensorflow/lite/profiling/proto/model_runtime_info.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
28,
3,
'',
'tensorflow/lite/profiling/proto/model_runtime_info.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from tensorflow.lite.profiling.proto import profiling_info_pb2 as tensorflow_dot_lite_dot_profiling_dot_proto_dot_profiling__info__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n8tensorflow/lite/profiling/proto/model_runtime_info.proto\x12\x10tflite.profiling\x1a\x34tensorflow/lite/profiling/proto/profiling_info.proto\"_\n\x13ModelRuntimeDetails\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12\x34\n\tsubgraphs\x18\x02 \x03(\x0b\x32!.tflite.profiling.RuntimeSubgraph\"\xb7\x02\n\x0fRuntimeSubgraph\x12\x13\n\x0bsubgraph_id\x18\x01 \x01(\x05\x12%\n\x05\x65\x64ges\x18\x02 \x03(\x0b\x32\x16.tflite.profiling.Edge\x12%\n\x05nodes\x18\x03 \x03(\x0b\x32\x16.tflite.profiling.Node\x12\x1a\n\x0e\x65xecution_plan\x18\x04 \x03(\x05\x42\x02\x10\x01\x12\x45\n\rsubgraph_type\x18\x05 \x01(\x0e\x32..tflite.profiling.RuntimeSubgraph.SubgraphType\x12\x0c\n\x04name\x18\x06 \x01(\t\"P\n\x0cSubgraphType\x12\x14\n\x10UNKNOWN_SUBGRAPH\x10\x00\x12\x13\n\x0fTFLITE_SUBGRAPH\x10\x01\x12\x15\n\x11\x44\x45LEGATE_SUBGRAPH\x10\x02\"\xba\x02\n\x04Node\x12\n\n\x02id\x18\x01 \x01(\x05\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x0c\n\x04type\x18\x03 \x01(\t\x12\x12\n\x06inputs\x18\x04 \x03(\x05\x42\x02\x10\x01\x12\x13\n\x07outputs\x18\x05 \x03(\x05\x42\x02\x10\x01\x12\x19\n\rintermediates\x18\x06 \x03(\x05\x42\x02\x10\x01\x12\x17\n\x0btemporaries\x18\x07 \x03(\x05\x42\x02\x10\x01\x12\x38\n\x0fop_profile_data\x18\n \x01(\x0b\x32\x1f.tflite.profiling.OpProfileData\x12\x46\n\x15\x64\x65legate_node_details\x18\x08 \x01(\x0b\x32%.tflite.profiling.DelegateNodeDetailsH\x00\x12\x1e\n\x14\x64\x65legated_to_node_id\x18\t \x01(\x05H\x00\x42\x0b\n\tnode_info\"R\n\x13\x44\x65legateNodeDetails\x12\x15\n\rdelegate_name\x18\x01 \x01(\t\x12$\n\x18tflite_node_ids_replaced\x18\x02 \x03(\x05\x42\x02\x10\x01\"\x81\x05\n\x04\x45\x64ge\x12\n\n\x02id\x18\x01 \x01(\x05\x12\x0c\n\x04name\x18\x02 \x01(\t\x12\x32\n\tdata_type\x18\x03 \x01(\x0e\x32\x1f.tflite.profiling.Edge.DataType\x12\x11\n\x05shape\x18\x04 \x03(\x05\x42\x02\x10\x01\x12\x17\n\x0f\x61llocation_type\x18\x05 \x01(\t\x12\x36\n\x0blayout_type\x18\x06 \x01(\x0e\x32!.tflite.profiling.Edge.LayoutType\x12\x0c\n\x04size\x18\x07 \x01(\x05\"\x85\x02\n\x08\x44\x61taType\x12\x10\n\x0cUNKNOWN_TYPE\x10\x00\x12\x0b\n\x07\x46LOAT32\x10\x01\x12\t\n\x05INT32\x10\x02\x12\t\n\x05UINT8\x10\x03\x12\t\n\x05INT64\x10\x04\x12\n\n\x06STRING\x10\x05\x12\x08\n\x04\x42OOL\x10\x06\x12\t\n\x05INT16\x10\x07\x12\r\n\tCOMPLEX64\x10\x08\x12\x08\n\x04INT8\x10\t\x12\x0b\n\x07\x46LOAT16\x10\n\x12\x0b\n\x07\x46LOAT64\x10\x0b\x12\x0e\n\nCOMPLEX128\x10\x0c\x12\n\n\x06UINT64\x10\r\x12\x0c\n\x08RESOURCE\x10\x0e\x12\x0b\n\x07VARIANT\x10\x0f\x12\n\n\x06UINT32\x10\x10\x12\n\n\x06UINT16\x10\x11\x12\x08\n\x04INT4\x10\x12\x12\x0c\n\x08\x42\x46LOAT16\x10\x13\"\xb0\x01\n\nLayoutType\x12\x0b\n\x07UNKNOWN\x10\x00\x12\n\n\x06SCALAR\x10\x01\x12\n\n\x06LINEAR\x10\x02\x12\x06\n\x02HW\x10\x03\x12\x07\n\x03\x43HW\x10\x04\x12\x07\n\x03HWC\x10\x05\x12\x08\n\x04OIHW\x10\x06\x12\x08\n\x04OHWI\x10\x07\x12\x08\n\x04IHWO\x10\x08\x12\x08\n\x04IOHW\x10\t\x12\x08\n\x04\x42HWC\x10\n\x12\x08\n\x04HWDC\x10\x0b\x12\t\n\x05\x42HWDC\x10\x0c\x12\x07\n\x03HWD\x10\r\x12\t\n\x05OHWDI\x10\x0e\x12\x08\n\x04HWIO\x10\x0f\x42\x14\n\x10tflite.profilingP\x01')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.lite.profiling.proto.model_runtime_info_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
_globals['DESCRIPTOR']._loaded_options = None
_globals['DESCRIPTOR']._serialized_options = b'\n\020tflite.profilingP\001'
_globals['_RUNTIMESUBGRAPH'].fields_by_name['execution_plan']._loaded_options = None
_globals['_RUNTIMESUBGRAPH'].fields_by_name['execution_plan']._serialized_options = b'\020\001'
_globals['_NODE'].fields_by_name['inputs']._loaded_options = None
_globals['_NODE'].fields_by_name['inputs']._serialized_options = b'\020\001'
_globals['_NODE'].fields_by_name['outputs']._loaded_options = None
_globals['_NODE'].fields_by_name['outputs']._serialized_options = b'\020\001'
_globals['_NODE'].fields_by_name['intermediates']._loaded_options = None
_globals['_NODE'].fields_by_name['intermediates']._serialized_options = b'\020\001'
_globals['_NODE'].fields_by_name['temporaries']._loaded_options = None
_globals['_NODE'].fields_by_name['temporaries']._serialized_options = b'\020\001'
_globals['_DELEGATENODEDETAILS'].fields_by_name['tflite_node_ids_replaced']._loaded_options = None
_globals['_DELEGATENODEDETAILS'].fields_by_name['tflite_node_ids_replaced']._serialized_options = b'\020\001'
_globals['_EDGE'].fields_by_name['shape']._loaded_options = None
_globals['_EDGE'].fields_by_name['shape']._serialized_options = b'\020\001'
_globals['_MODELRUNTIMEDETAILS']._serialized_start=132
_globals['_MODELRUNTIMEDETAILS']._serialized_end=227
_globals['_RUNTIMESUBGRAPH']._serialized_start=230
_globals['_RUNTIMESUBGRAPH']._serialized_end=541
_globals['_RUNTIMESUBGRAPH_SUBGRAPHTYPE']._serialized_start=461
_globals['_RUNTIMESUBGRAPH_SUBGRAPHTYPE']._serialized_end=541
_globals['_NODE']._serialized_start=544
_globals['_NODE']._serialized_end=858
_globals['_DELEGATENODEDETAILS']._serialized_start=860
_globals['_DELEGATENODEDETAILS']._serialized_end=942
_globals['_EDGE']._serialized_start=945
_globals['_EDGE']._serialized_end=1586
_globals['_EDGE_DATATYPE']._serialized_start=1146
_globals['_EDGE_DATATYPE']._serialized_end=1407
_globals['_EDGE_LAYOUTTYPE']._serialized_start=1410
_globals['_EDGE_LAYOUTTYPE']._serialized_end=1586
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: tensorflow/lite/profiling/proto/profiling_info.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
28,
3,
'',
'tensorflow/lite/profiling/proto/profiling_info.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n4tensorflow/lite/profiling/proto/profiling_info.proto\x12\x10tflite.profiling\"\xa7\x01\n\x16\x42\x65nchmarkProfilingData\x12\x12\n\nmodel_name\x18\x01 \x01(\t\x12:\n\x0cinit_profile\x18\x02 \x01(\x0b\x32$.tflite.profiling.ModelProfilingData\x12=\n\x0fruntime_profile\x18\x03 \x01(\x0b\x32$.tflite.profiling.ModelProfilingData\"\x9c\x01\n\x12ModelProfilingData\x12\x42\n\x11subgraph_profiles\x18\x01 \x03(\x0b\x32\'.tflite.profiling.SubGraphProfilingData\x12\x42\n\x11\x64\x65legate_profiles\x18\x02 \x03(\x0b\x32\'.tflite.profiling.DelegateProfilingData\"\x80\x01\n\x15SubGraphProfilingData\x12\x15\n\rsubgraph_name\x18\x01 \x01(\t\x12\x16\n\x0esubgraph_index\x18\x02 \x01(\x05\x12\x38\n\x0fper_op_profiles\x18\x03 \x03(\x0b\x32\x1f.tflite.profiling.OpProfileData\"h\n\x15\x44\x65legateProfilingData\x12\x15\n\rdelegate_name\x18\x01 \x01(\t\x12\x38\n\x0fper_op_profiles\x18\x02 \x03(\x0b\x32\x1f.tflite.profiling.OpProfileData\"\x93\x01\n\x0fOpProfilingStat\x12\r\n\x05\x66irst\x18\x01 \x01(\x03\x12\x0c\n\x04last\x18\x02 \x01(\x03\x12\x0b\n\x03\x61vg\x18\x03 \x01(\x03\x12\x0e\n\x06stddev\x18\x04 \x01(\x02\x12\x10\n\x08variance\x18\x05 \x01(\x02\x12\x0b\n\x03min\x18\x06 \x01(\x03\x12\x0b\n\x03max\x18\x07 \x01(\x03\x12\x0b\n\x03sum\x18\x08 \x01(\x03\x12\r\n\x05\x63ount\x18\t \x01(\x03\"\xcf\x01\n\rOpProfileData\x12\x11\n\tnode_type\x18\x01 \x01(\t\x12\x41\n\x16inference_microseconds\x18\x02 \x01(\x0b\x32!.tflite.profiling.OpProfilingStat\x12\x31\n\x06mem_kb\x18\x03 \x01(\x0b\x32!.tflite.profiling.OpProfilingStat\x12\x14\n\x0ctimes_called\x18\x04 \x01(\x03\x12\x0c\n\x04name\x18\x05 \x01(\t\x12\x11\n\trun_order\x18\x06 \x01(\x03\x42\x14\n\x10tflite.profilingP\x01')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.lite.profiling.proto.profiling_info_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
_globals['DESCRIPTOR']._loaded_options = None
_globals['DESCRIPTOR']._serialized_options = b'\n\020tflite.profilingP\001'
_globals['_BENCHMARKPROFILINGDATA']._serialized_start=75
_globals['_BENCHMARKPROFILINGDATA']._serialized_end=242
_globals['_MODELPROFILINGDATA']._serialized_start=245
_globals['_MODELPROFILINGDATA']._serialized_end=401
_globals['_SUBGRAPHPROFILINGDATA']._serialized_start=404
_globals['_SUBGRAPHPROFILINGDATA']._serialized_end=532
_globals['_DELEGATEPROFILINGDATA']._serialized_start=534
_globals['_DELEGATEPROFILINGDATA']._serialized_end=638
_globals['_OPPROFILINGSTAT']._serialized_start=641
_globals['_OPPROFILINGSTAT']._serialized_end=788
_globals['_OPPROFILEDATA']._serialized_start=791
_globals['_OPPROFILEDATA']._serialized_end=998
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,107 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This tool analyzes a TensorFlow Lite graph."""
import os
# pylint: disable=g-import-not-at-top
if not os.path.splitext(__file__)[0].endswith(
os.path.join("tflite_runtime", "analyzer")):
# This file is part of tensorflow package.
from tensorflow.compiler.mlir.lite.python import wrap_converter
from tensorflow.lite.python.analyzer_wrapper import _pywrap_analyzer_wrapper as _analyzer_wrapper
from tensorflow.python.util.tf_export import tf_export as _tf_export
else:
# This file is part of tflite_runtime package.
from tflite_runtime import _pywrap_analyzer_wrapper as _analyzer_wrapper
def _tf_export(*x, **kwargs):
del x, kwargs
return lambda x: x
@_tf_export("lite.experimental.Analyzer")
class ModelAnalyzer():
"""Provides a collection of TFLite model analyzer tools.
Example:
```python
model = tf.keras.applications.MobileNetV3Large()
fb_model = tf.lite.TFLiteConverterV2.from_keras_model(model).convert()
tf.lite.experimental.Analyzer.analyze(model_content=fb_model)
# === TFLite ModelAnalyzer ===
#
# Your TFLite model has 1 subgraph(s). In the subgraph description below,
# T# represents the Tensor numbers. For example, in Subgraph#0, the MUL op
# takes tensor #0 and tensor #19 as input and produces tensor #136 as output.
#
# Subgraph#0 main(T#0) -> [T#263]
# Op#0 MUL(T#0, T#19) -> [T#136]
# Op#1 ADD(T#136, T#18) -> [T#137]
# Op#2 CONV_2D(T#137, T#44, T#93) -> [T#138]
# Op#3 HARD_SWISH(T#138) -> [T#139]
# Op#4 DEPTHWISE_CONV_2D(T#139, T#94, T#24) -> [T#140]
# ...
```
WARNING: Experimental interface, subject to change.
"""
@staticmethod
def analyze(model_path=None,
model_content=None,
gpu_compatibility=False,
**kwargs):
"""Analyzes the given tflite_model with dumping model structure.
This tool provides a way to understand users' TFLite flatbuffer model by
dumping internal graph structure. It also provides additional features
like checking GPU delegate compatibility.
WARNING: Experimental interface, subject to change.
The output format is not guaranteed to stay stable, so don't
write scripts to this.
Args:
model_path: TFLite flatbuffer model path.
model_content: TFLite flatbuffer model object.
gpu_compatibility: Whether to check GPU delegate compatibility.
**kwargs: Experimental keyword arguments to analyze API.
Returns:
Print analyzed report via console output.
"""
if not model_path and not model_content:
raise ValueError("neither `model_path` nor `model_content` is provided")
if model_path:
print(f"=== {model_path} ===\n")
tflite_model = model_path
input_is_filepath = True
else:
print("=== TFLite ModelAnalyzer ===\n")
tflite_model = model_content
input_is_filepath = False
if kwargs.get("experimental_use_mlir", False):
print(
wrap_converter.wrapped_flat_buffer_file_to_mlir(
tflite_model, input_is_filepath
)
)
else:
print(
_analyzer_wrapper.ModelAnalyzer(tflite_model, input_is_filepath,
gpu_compatibility))
@@ -0,0 +1,301 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow Authoring tool package for TFLite compatibility.
WARNING: The package is experimental and subject to change.
This package provides a way to check TFLite compatibility at model authoring
time.
Example:
@tf.lite.experimental.authoring.compatible
@tf.function(input_signature=[
tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
return tf.cosh(x)
result = f(tf.constant([0.0]))
> COMPATIBILITY WARNING: op 'tf.Cosh' require(s) "Select TF Ops" for model
> conversion for TensorFlow Lite.
> Op: tf.Cosh
> - tensorflow/python/framework/op_def_library.py:xxx
> - tensorflow/python/ops/gen_math_ops.py:xxx
> - simple_authoring.py:xxx
"""
import functools
from tensorflow.compiler.mlir.lite.metrics import converter_error_data_pb2
# pylint: disable=g-import-not-at-top
from tensorflow.lite.python import convert
from tensorflow.lite.python import lite
from tensorflow.python.util.tf_export import tf_export as _tf_export
_CUSTOM_OPS_HDR = "Custom ops: "
_TF_OPS_HDR = "TF Select ops: "
_AUTHORING_ERROR_HDR = "COMPATIBILITY ERROR"
_AUTHORING_WARNING_HDR = "COMPATIBILITY WARNING"
_FUNC_GRAPH_SRC_PATH = "tensorflow/python/framework/func_graph.py"
class CompatibilityError(Exception):
"""Raised when an error occurs with TFLite compatibility."""
pass
class _Compatible:
"""A decorator class to check TFLite compatibility created by `lite.experimental.authoring.compatible`."""
def __init__(self,
target,
converter_target_spec=None,
converter_allow_custom_ops=None,
raise_exception=False):
"""Initialize the decorator object.
Here is the description of the object variables.
- _func : decorated function.
- _obj_func : for class object, we need to use this object to provide `self`
instance as 1 first argument.
- _verified : whether the compatibility is checked or not.
Args:
target: decorated function.
converter_target_spec : target_spec of TFLite converter parameter.
converter_allow_custom_ops : allow_custom_ops of TFLite converter
parameter.
raise_exception : to raise an exception on compatibility issues.
User need to use get_compatibility_log() to check details.
"""
functools.update_wrapper(self, target)
self._func = target
self._obj_func = None
self._verified = False
self._log_messages = []
self._raise_exception = raise_exception
self._converter_target_spec = converter_target_spec
self._converter_allow_custom_ops = converter_allow_custom_ops
def __get__(self, instance, cls):
"""A Python descriptor interface."""
self._obj_func = self._func.__get__(instance, cls)
return self
def _get_func(self):
"""Returns decorated function object.
For a class method, use self._obj_func to provide `self` instance.
"""
if self._obj_func is not None:
return self._obj_func
else:
return self._func
def __call__(self, *args, **kwargs): # pylint: disable=g-doc-args
"""Calls decorated function object.
Also verifies if the function is compatible with TFLite.
Returns:
A execution result of the decorated function.
"""
if not self._verified:
model = self._get_func()
concrete_func = model.get_concrete_function(*args, **kwargs)
converter = lite.TFLiteConverterV2.from_concrete_functions(
[concrete_func], model)
# Set provided converter parameters
if self._converter_target_spec is not None:
converter.target_spec = self._converter_target_spec
if self._converter_allow_custom_ops is not None:
converter.allow_custom_ops = self._converter_allow_custom_ops
try:
converter.convert()
except convert.ConverterError as err:
self._decode_error(err)
finally:
self._verified = True
return self._get_func()(*args, **kwargs)
def get_concrete_function(self, *args, **kwargs):
"""Returns a concrete function of the decorated function."""
return self._get_func().get_concrete_function(*args, **kwargs)
def _get_location_string(self, location):
"""Dump location of ConveterError.errors.location."""
callstack = []
for single_call in reversed(location.call):
if (location.type ==
converter_error_data_pb2.ConverterErrorData.CALLSITELOC):
callstack.append(
f" - {single_call.source.filename}:{single_call.source.line}")
else:
callstack.append(str(single_call))
callstack_dump = "\n".join(callstack)
return callstack_dump
def _dump_error_details(self, ops, locations):
"""Dump the list of ops and locations."""
for i in range(0, len(ops)):
callstack_dump = self._get_location_string(locations[i])
err_string = f"Op: {ops[i]}\n{callstack_dump}\n"
self._log(err_string)
def _decode_error_legacy(self, err):
"""Parses the given legacy ConverterError for OSS."""
for line in str(err).splitlines():
# Check custom op usage error.
if line.startswith(_CUSTOM_OPS_HDR):
custom_ops = line[len(_CUSTOM_OPS_HDR):]
err_string = (
f"{_AUTHORING_ERROR_HDR}: op '{custom_ops}' is(are) not natively "
"supported by TensorFlow Lite. You need to provide a custom "
"operator. https://www.tensorflow.org/lite/guide/ops_custom")
self._log(err_string)
# Check TensorFlow op usage error.
elif line.startswith(_TF_OPS_HDR):
tf_ops = line[len(_TF_OPS_HDR):]
err_string = (
f"{_AUTHORING_WARNING_HDR}: op '{tf_ops}' require(s) \"Select TF "
"Ops\" for model conversion for TensorFlow Lite. "
"https://www.tensorflow.org/lite/guide/ops_select")
self._log(err_string)
def _decode_converter_error(self, err):
"""Parses the given ConverterError which has detailed error information."""
custom_ops = []
custom_ops_location = []
tf_ops = []
tf_ops_location = []
gpu_not_compatible_ops = []
for err in err.errors:
# Check custom op usage error.
if err.error_code == converter_error_data_pb2.ConverterErrorData.ERROR_NEEDS_CUSTOM_OPS:
custom_ops.append(err.operator.name)
custom_ops_location.append(err.location)
# Check TensorFlow op usage error.
elif err.error_code == converter_error_data_pb2.ConverterErrorData.ERROR_NEEDS_FLEX_OPS:
tf_ops.append(err.operator.name)
tf_ops_location.append(err.location)
# Check GPU delegate compatibility error.
elif err.error_code == converter_error_data_pb2.ConverterErrorData.ERROR_GPU_NOT_COMPATIBLE:
gpu_not_compatible_ops.append(err.operator.name)
# Log the first line of ConveterError.errors.error_message only
# since the seond line is "Error code: xxxx"
self._log(err.error_message.splitlines()[0])
self._log(self._get_location_string(err.location) + "\n")
else:
# Log other errors.
self._log(f"{_AUTHORING_ERROR_HDR}: {err.error_message}")
self._log(self._get_location_string(err.location) + "\n")
if custom_ops:
custom_ops_str = ", ".join(sorted(custom_ops))
err_string = (
f"{_AUTHORING_ERROR_HDR}: op '{custom_ops_str}' is(are) not natively "
"supported by TensorFlow Lite. You need to provide a custom "
"operator. https://www.tensorflow.org/lite/guide/ops_custom")
self._log(err_string)
self._dump_error_details(custom_ops, custom_ops_location)
if tf_ops:
tf_ops_str = ", ".join(sorted(tf_ops))
err_string = (
f"{_AUTHORING_WARNING_HDR}: op '{tf_ops_str}' require(s) \"Select TF"
" Ops\" for model conversion for TensorFlow Lite. "
"https://www.tensorflow.org/lite/guide/ops_select")
self._log(err_string)
self._dump_error_details(tf_ops, tf_ops_location)
if gpu_not_compatible_ops:
not_compatible_ops_str = ", ".join(sorted(gpu_not_compatible_ops))
err_string = (
f"{_AUTHORING_WARNING_HDR}: op '{not_compatible_ops_str}' aren't "
"compatible with TensorFlow Lite GPU delegate. "
"https://www.tensorflow.org/lite/performance/gpu")
self._log(err_string)
def _decode_error(self, err):
"""Parses the given ConverterError and generates compatibility warnings."""
if hasattr(err, "errors"):
self._decode_converter_error(err)
else:
self._decode_error_legacy(err)
if self._raise_exception and self._log_messages:
raise CompatibilityError(f"CompatibilityException at {repr(self._func)}")
def _log(self, message):
"""Log and print authoring warning / error message."""
self._log_messages.append(message)
print(message)
def get_compatibility_log(self):
"""Returns list of compatibility log messages.
WARNING: This method should only be used for unit tests.
Returns:
The list of log messages by the recent compatibility check.
Raises:
RuntimeError: when the compatibility was NOT checked.
"""
if not self._verified:
raise RuntimeError("target compatibility isn't verified yet")
return self._log_messages
@_tf_export("lite.experimental.authoring.compatible")
def compatible(target=None, converter_target_spec=None, **kwargs):
"""Wraps `tf.function` into a callable function with TFLite compatibility checking.
Example:
```python
@tf.lite.experimental.authoring.compatible
@tf.function(input_signature=[
tf.TensorSpec(shape=[None], dtype=tf.float32)
])
def f(x):
return tf.cosh(x)
result = f(tf.constant([0.0]))
# COMPATIBILITY WARNING: op 'tf.Cosh' require(s) "Select TF Ops" for model
# conversion for TensorFlow Lite.
# Op: tf.Cosh
# - tensorflow/python/framework/op_def_library.py:748
# - tensorflow/python/ops/gen_math_ops.py:2458
# - <stdin>:6
```
WARNING: Experimental interface, subject to change.
Args:
target: A `tf.function` to decorate.
converter_target_spec : target_spec of TFLite converter parameter.
**kwargs: The keyword arguments of the decorator class _Compatible.
Returns:
A callable object of `tf.lite.experimental.authoring._Compatible`.
"""
if target is None:
def wrapper(target):
return _Compatible(target, converter_target_spec, **kwargs)
return wrapper
else:
return _Compatible(target, converter_target_spec, **kwargs)
@@ -0,0 +1,568 @@
import flatbuffers
# automatically generated by the FlatBuffers compiler, do not modify
# namespace: tflite
from flatbuffers.compat import import_numpy
np = import_numpy()
class ModelType(object):
NONE = 0
TF_SAVED_MODEL = 1
KERAS_MODEL = 2
TF_CONCRETE_FUNCTIONS = 3
TF_GRAPH_DEF = 4
TF_SESSION = 5
JAX = 6
PYTORCH = 7
class ModelOptimizationMode(object):
PTQ_FLOAT16 = 1001
PTQ_DYNAMIC_RANGE = 1002
PTQ_FULL_INTEGER = 1003
PTQ_INT16 = 1004
QUANTIZATION_AWARE_TRAINING = 2000
RANDOM_SPARSITY = 3001
BLOCK_SPARSITY = 3002
STRUCTURED_SPARSITY = 3003
class Environment(object):
__slots__ = ['_tab']
@classmethod
def GetRootAs(cls, buf, offset=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
x = Environment()
x.Init(buf, n + offset)
return x
@classmethod
def GetRootAsEnvironment(cls, buf, offset=0):
"""This method is deprecated. Please switch to GetRootAs."""
return cls.GetRootAs(buf, offset)
# Environment
def Init(self, buf, pos):
self._tab = flatbuffers.table.Table(buf, pos)
# Environment
def TensorflowVersion(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
if o != 0:
return self._tab.String(o + self._tab.Pos)
return None
# Environment
def ApiVersion(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Uint32Flags, o + self._tab.Pos)
return 0
# Environment
def ModelType(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Int32Flags, o + self._tab.Pos)
return 0
# Environment
def ModelHash(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
if o != 0:
return self._tab.Get(flatbuffers.number_types.Uint64Flags, o + self._tab.Pos)
return 0
def EnvironmentStart(builder):
builder.StartObject(4)
def EnvironmentAddTensorflowVersion(builder, tensorflowVersion):
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(tensorflowVersion), 0)
def EnvironmentAddApiVersion(builder, apiVersion):
builder.PrependUint32Slot(1, apiVersion, 0)
def EnvironmentAddModelType(builder, modelType):
builder.PrependInt32Slot(2, modelType, 0)
def EnvironmentAddModelHash(builder, modelHash):
builder.PrependUint64Slot(3, modelHash, 0)
def EnvironmentEnd(builder):
return builder.EndObject()
class EnvironmentT(object):
# EnvironmentT
def __init__(self):
self.tensorflowVersion = None # type: str
self.apiVersion = 0 # type: int
self.modelType = 0 # type: int
self.modelHash = 0 # type: int
@classmethod
def InitFromBuf(cls, buf, pos):
environment = Environment()
environment.Init(buf, pos)
return cls.InitFromObj(environment)
@classmethod
def InitFromPackedBuf(cls, buf, pos=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
return cls.InitFromBuf(buf, pos+n)
@classmethod
def InitFromObj(cls, environment):
x = EnvironmentT()
x._UnPack(environment)
return x
# EnvironmentT
def _UnPack(self, environment):
if environment is None:
return
self.tensorflowVersion = environment.TensorflowVersion()
self.apiVersion = environment.ApiVersion()
self.modelType = environment.ModelType()
self.modelHash = environment.ModelHash()
# EnvironmentT
def Pack(self, builder):
if self.tensorflowVersion is not None:
tensorflowVersion = builder.CreateString(self.tensorflowVersion)
EnvironmentStart(builder)
if self.tensorflowVersion is not None:
EnvironmentAddTensorflowVersion(builder, tensorflowVersion)
EnvironmentAddApiVersion(builder, self.apiVersion)
EnvironmentAddModelType(builder, self.modelType)
EnvironmentAddModelHash(builder, self.modelHash)
environment = EnvironmentEnd(builder)
return environment
class SparsityBlockSize(object):
__slots__ = ['_tab']
@classmethod
def GetRootAs(cls, buf, offset=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
x = SparsityBlockSize()
x.Init(buf, n + offset)
return x
@classmethod
def GetRootAsSparsityBlockSize(cls, buf, offset=0):
"""This method is deprecated. Please switch to GetRootAs."""
return cls.GetRootAs(buf, offset)
# SparsityBlockSize
def Init(self, buf, pos):
self._tab = flatbuffers.table.Table(buf, pos)
# SparsityBlockSize
def Values(self, j):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
if o != 0:
a = self._tab.Vector(o)
return self._tab.Get(flatbuffers.number_types.Uint32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
return 0
# SparsityBlockSize
def ValuesAsNumpy(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
if o != 0:
return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Uint32Flags, o)
return 0
# SparsityBlockSize
def ValuesLength(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
if o != 0:
return self._tab.VectorLen(o)
return 0
# SparsityBlockSize
def ValuesIsNone(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
return o == 0
def SparsityBlockSizeStart(builder):
builder.StartObject(1)
def SparsityBlockSizeAddValues(builder, values):
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(values), 0)
def SparsityBlockSizeStartValuesVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
def SparsityBlockSizeEnd(builder):
return builder.EndObject()
try:
from typing import List
except:
pass
class SparsityBlockSizeT(object):
# SparsityBlockSizeT
def __init__(self):
self.values = None # type: List[int]
@classmethod
def InitFromBuf(cls, buf, pos):
sparsityBlockSize = SparsityBlockSize()
sparsityBlockSize.Init(buf, pos)
return cls.InitFromObj(sparsityBlockSize)
@classmethod
def InitFromPackedBuf(cls, buf, pos=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
return cls.InitFromBuf(buf, pos+n)
@classmethod
def InitFromObj(cls, sparsityBlockSize):
x = SparsityBlockSizeT()
x._UnPack(sparsityBlockSize)
return x
# SparsityBlockSizeT
def _UnPack(self, sparsityBlockSize):
if sparsityBlockSize is None:
return
if not sparsityBlockSize.ValuesIsNone():
if np is None:
self.values = []
for i in range(sparsityBlockSize.ValuesLength()):
self.values.append(sparsityBlockSize.Values(i))
else:
self.values = sparsityBlockSize.ValuesAsNumpy()
# SparsityBlockSizeT
def Pack(self, builder):
if self.values is not None:
if np is not None and type(self.values) is np.ndarray:
values = builder.CreateNumpyVector(self.values)
else:
SparsityBlockSizeStartValuesVector(builder, len(self.values))
for i in reversed(range(len(self.values))):
builder.PrependUint32(self.values[i])
values = builder.EndVector()
SparsityBlockSizeStart(builder)
if self.values is not None:
SparsityBlockSizeAddValues(builder, values)
sparsityBlockSize = SparsityBlockSizeEnd(builder)
return sparsityBlockSize
class ConversionOptions(object):
__slots__ = ['_tab']
@classmethod
def GetRootAs(cls, buf, offset=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
x = ConversionOptions()
x.Init(buf, n + offset)
return x
@classmethod
def GetRootAsConversionOptions(cls, buf, offset=0):
"""This method is deprecated. Please switch to GetRootAs."""
return cls.GetRootAs(buf, offset)
# ConversionOptions
def Init(self, buf, pos):
self._tab = flatbuffers.table.Table(buf, pos)
# ConversionOptions
def ModelOptimizationModes(self, j):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
if o != 0:
a = self._tab.Vector(o)
return self._tab.Get(flatbuffers.number_types.Int32Flags, a + flatbuffers.number_types.UOffsetTFlags.py_type(j * 4))
return 0
# ConversionOptions
def ModelOptimizationModesAsNumpy(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
if o != 0:
return self._tab.GetVectorAsNumpy(flatbuffers.number_types.Int32Flags, o)
return 0
# ConversionOptions
def ModelOptimizationModesLength(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
if o != 0:
return self._tab.VectorLen(o)
return 0
# ConversionOptions
def ModelOptimizationModesIsNone(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
return o == 0
# ConversionOptions
def AllowCustomOps(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
if o != 0:
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
return False
# ConversionOptions
def EnableSelectTfOps(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(8))
if o != 0:
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
return False
# ConversionOptions
def ForceSelectTfOps(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(10))
if o != 0:
return bool(self._tab.Get(flatbuffers.number_types.BoolFlags, o + self._tab.Pos))
return False
# ConversionOptions
def SparsityBlockSizes(self, j):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
if o != 0:
x = self._tab.Vector(o)
x += flatbuffers.number_types.UOffsetTFlags.py_type(j) * 4
x = self._tab.Indirect(x)
obj = SparsityBlockSize()
obj.Init(self._tab.Bytes, x)
return obj
return None
# ConversionOptions
def SparsityBlockSizesLength(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
if o != 0:
return self._tab.VectorLen(o)
return 0
# ConversionOptions
def SparsityBlockSizesIsNone(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(12))
return o == 0
def ConversionOptionsStart(builder):
builder.StartObject(5)
def ConversionOptionsAddModelOptimizationModes(builder, modelOptimizationModes):
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(modelOptimizationModes), 0)
def ConversionOptionsStartModelOptimizationModesVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
def ConversionOptionsAddAllowCustomOps(builder, allowCustomOps):
builder.PrependBoolSlot(1, allowCustomOps, 0)
def ConversionOptionsAddEnableSelectTfOps(builder, enableSelectTfOps):
builder.PrependBoolSlot(2, enableSelectTfOps, 0)
def ConversionOptionsAddForceSelectTfOps(builder, forceSelectTfOps):
builder.PrependBoolSlot(3, forceSelectTfOps, 0)
def ConversionOptionsAddSparsityBlockSizes(builder, sparsityBlockSizes):
builder.PrependUOffsetTRelativeSlot(4, flatbuffers.number_types.UOffsetTFlags.py_type(sparsityBlockSizes), 0)
def ConversionOptionsStartSparsityBlockSizesVector(builder, numElems):
return builder.StartVector(4, numElems, 4)
def ConversionOptionsEnd(builder):
return builder.EndObject()
try:
from typing import List
except:
pass
class ConversionOptionsT(object):
# ConversionOptionsT
def __init__(self):
self.modelOptimizationModes = None # type: List[int]
self.allowCustomOps = False # type: bool
self.enableSelectTfOps = False # type: bool
self.forceSelectTfOps = False # type: bool
self.sparsityBlockSizes = None # type: List[SparsityBlockSizeT]
@classmethod
def InitFromBuf(cls, buf, pos):
conversionOptions = ConversionOptions()
conversionOptions.Init(buf, pos)
return cls.InitFromObj(conversionOptions)
@classmethod
def InitFromPackedBuf(cls, buf, pos=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
return cls.InitFromBuf(buf, pos+n)
@classmethod
def InitFromObj(cls, conversionOptions):
x = ConversionOptionsT()
x._UnPack(conversionOptions)
return x
# ConversionOptionsT
def _UnPack(self, conversionOptions):
if conversionOptions is None:
return
if not conversionOptions.ModelOptimizationModesIsNone():
if np is None:
self.modelOptimizationModes = []
for i in range(conversionOptions.ModelOptimizationModesLength()):
self.modelOptimizationModes.append(conversionOptions.ModelOptimizationModes(i))
else:
self.modelOptimizationModes = conversionOptions.ModelOptimizationModesAsNumpy()
self.allowCustomOps = conversionOptions.AllowCustomOps()
self.enableSelectTfOps = conversionOptions.EnableSelectTfOps()
self.forceSelectTfOps = conversionOptions.ForceSelectTfOps()
if not conversionOptions.SparsityBlockSizesIsNone():
self.sparsityBlockSizes = []
for i in range(conversionOptions.SparsityBlockSizesLength()):
if conversionOptions.SparsityBlockSizes(i) is None:
self.sparsityBlockSizes.append(None)
else:
sparsityBlockSize_ = SparsityBlockSizeT.InitFromObj(conversionOptions.SparsityBlockSizes(i))
self.sparsityBlockSizes.append(sparsityBlockSize_)
# ConversionOptionsT
def Pack(self, builder):
if self.modelOptimizationModes is not None:
if np is not None and type(self.modelOptimizationModes) is np.ndarray:
modelOptimizationModes = builder.CreateNumpyVector(self.modelOptimizationModes)
else:
ConversionOptionsStartModelOptimizationModesVector(builder, len(self.modelOptimizationModes))
for i in reversed(range(len(self.modelOptimizationModes))):
builder.PrependInt32(self.modelOptimizationModes[i])
modelOptimizationModes = builder.EndVector()
if self.sparsityBlockSizes is not None:
sparsityBlockSizeslist = []
for i in range(len(self.sparsityBlockSizes)):
sparsityBlockSizeslist.append(self.sparsityBlockSizes[i].Pack(builder))
ConversionOptionsStartSparsityBlockSizesVector(builder, len(self.sparsityBlockSizes))
for i in reversed(range(len(self.sparsityBlockSizes))):
builder.PrependUOffsetTRelative(sparsityBlockSizeslist[i])
sparsityBlockSizes = builder.EndVector()
ConversionOptionsStart(builder)
if self.modelOptimizationModes is not None:
ConversionOptionsAddModelOptimizationModes(builder, modelOptimizationModes)
ConversionOptionsAddAllowCustomOps(builder, self.allowCustomOps)
ConversionOptionsAddEnableSelectTfOps(builder, self.enableSelectTfOps)
ConversionOptionsAddForceSelectTfOps(builder, self.forceSelectTfOps)
if self.sparsityBlockSizes is not None:
ConversionOptionsAddSparsityBlockSizes(builder, sparsityBlockSizes)
conversionOptions = ConversionOptionsEnd(builder)
return conversionOptions
class ConversionMetadata(object):
__slots__ = ['_tab']
@classmethod
def GetRootAs(cls, buf, offset=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, offset)
x = ConversionMetadata()
x.Init(buf, n + offset)
return x
@classmethod
def GetRootAsConversionMetadata(cls, buf, offset=0):
"""This method is deprecated. Please switch to GetRootAs."""
return cls.GetRootAs(buf, offset)
# ConversionMetadata
def Init(self, buf, pos):
self._tab = flatbuffers.table.Table(buf, pos)
# ConversionMetadata
def Environment(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(4))
if o != 0:
x = self._tab.Indirect(o + self._tab.Pos)
obj = Environment()
obj.Init(self._tab.Bytes, x)
return obj
return None
# ConversionMetadata
def Options(self):
o = flatbuffers.number_types.UOffsetTFlags.py_type(self._tab.Offset(6))
if o != 0:
x = self._tab.Indirect(o + self._tab.Pos)
obj = ConversionOptions()
obj.Init(self._tab.Bytes, x)
return obj
return None
def ConversionMetadataStart(builder):
builder.StartObject(2)
def ConversionMetadataAddEnvironment(builder, environment):
builder.PrependUOffsetTRelativeSlot(0, flatbuffers.number_types.UOffsetTFlags.py_type(environment), 0)
def ConversionMetadataAddOptions(builder, options):
builder.PrependUOffsetTRelativeSlot(1, flatbuffers.number_types.UOffsetTFlags.py_type(options), 0)
def ConversionMetadataEnd(builder):
return builder.EndObject()
try:
from typing import Optional
except:
pass
class ConversionMetadataT(object):
# ConversionMetadataT
def __init__(self):
self.environment = None # type: Optional[EnvironmentT]
self.options = None # type: Optional[ConversionOptionsT]
@classmethod
def InitFromBuf(cls, buf, pos):
conversionMetadata = ConversionMetadata()
conversionMetadata.Init(buf, pos)
return cls.InitFromObj(conversionMetadata)
@classmethod
def InitFromPackedBuf(cls, buf, pos=0):
n = flatbuffers.encode.Get(flatbuffers.packer.uoffset, buf, pos)
return cls.InitFromBuf(buf, pos+n)
@classmethod
def InitFromObj(cls, conversionMetadata):
x = ConversionMetadataT()
x._UnPack(conversionMetadata)
return x
# ConversionMetadataT
def _UnPack(self, conversionMetadata):
if conversionMetadata is None:
return
if conversionMetadata.Environment() is not None:
self.environment = EnvironmentT.InitFromObj(conversionMetadata.Environment())
if conversionMetadata.Options() is not None:
self.options = ConversionOptionsT.InitFromObj(conversionMetadata.Options())
# ConversionMetadataT
def Pack(self, builder):
if self.environment is not None:
environment = self.environment.Pack(builder)
if self.options is not None:
options = self.options.Pack(builder)
ConversionMetadataStart(builder)
if self.environment is not None:
ConversionMetadataAddEnvironment(builder, environment)
if self.options is not None:
ConversionMetadataAddOptions(builder, options)
conversionMetadata = ConversionMetadataEnd(builder)
return conversionMetadata
@@ -0,0 +1,219 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for collecting TFLite metrics."""
import collections
import enum
import functools
from typing import Text
from tensorflow.compiler.mlir.lite.metrics import converter_error_data_pb2
from tensorflow.lite.python.metrics import metrics
class Component(enum.Enum):
"""Enum class defining name of the converter components."""
# Validate the given input and prepare and optimize TensorFlow Model.
PREPARE_TF_MODEL = "PREPARE_TF_MODEL"
# Convert to TFLite model format.
CONVERT_TF_TO_TFLITE_MODEL = "CONVERT_TF_TO_TFLITE_MODEL"
# RUN quantization and sparsification.
OPTIMIZE_TFLITE_MODEL = "OPTIMIZE_TFLITE_MODEL"
SubComponentItem = collections.namedtuple("SubComponentItem",
["name", "component"])
class SubComponent(SubComponentItem, enum.Enum):
"""Enum class defining name of the converter subcomponents.
This enum only defines the subcomponents in Python, there might be more
subcomponents defined in C++.
"""
def __str__(self):
return self.value.name
@property
def name(self):
return self.value.name
@property
def component(self):
return self.value.component
# The subcomponent name is unspecified.
UNSPECIFIED = SubComponentItem("UNSPECIFIED", None)
# Valid the given input and parameters.
VALIDATE_INPUTS = SubComponentItem("VALIDATE_INPUTS",
Component.PREPARE_TF_MODEL)
# Load GraphDef from SavedModel.
LOAD_SAVED_MODEL = SubComponentItem("LOAD_SAVED_MODEL",
Component.PREPARE_TF_MODEL)
# Convert a SavedModel to frozen graph.
FREEZE_SAVED_MODEL = SubComponentItem("FREEZE_SAVED_MODEL",
Component.PREPARE_TF_MODEL)
# Save a Keras model to SavedModel.
CONVERT_KERAS_TO_SAVED_MODEL = SubComponentItem(
"CONVERT_KERAS_TO_SAVED_MODEL", Component.PREPARE_TF_MODEL)
# Save Concrete functions to SavedModel.
CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL = SubComponentItem(
"CONVERT_CONCRETE_FUNCTIONS_TO_SAVED_MODEL", Component.PREPARE_TF_MODEL)
# Convert a Keras model to a frozen graph.
FREEZE_KERAS_MODEL = SubComponentItem("FREEZE_KERAS_MODEL",
Component.PREPARE_TF_MODEL)
# Replace all the variables with constants in a ConcreteFunction.
FREEZE_CONCRETE_FUNCTION = SubComponentItem("FREEZE_CONCRETE_FUNCTION",
Component.PREPARE_TF_MODEL)
# Run grappler optimization.
OPTIMIZE_TF_MODEL = SubComponentItem("OPTIMIZE_TF_MODEL",
Component.PREPARE_TF_MODEL)
# Convert using the old TOCO converter.
CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER = SubComponentItem(
"CONVERT_GRAPHDEF_USING_DEPRECATED_CONVERTER",
Component.CONVERT_TF_TO_TFLITE_MODEL)
# Convert a GraphDef to TFLite model.
CONVERT_GRAPHDEF = SubComponentItem("CONVERT_GRAPHDEF",
Component.CONVERT_TF_TO_TFLITE_MODEL)
# Convert a SavedModel to TFLite model.
CONVERT_SAVED_MODEL = SubComponentItem("CONVERT_SAVED_MODEL",
Component.CONVERT_TF_TO_TFLITE_MODEL)
# Convert a Jax HLO to TFLite model.
CONVERT_JAX_HLO = SubComponentItem("CONVERT_JAX_HLO",
Component.CONVERT_TF_TO_TFLITE_MODEL)
# Do quantization by the deprecated quantizer.
QUANTIZE_USING_DEPRECATED_QUANTIZER = SubComponentItem(
"QUANTIZE_USING_DEPRECATED_QUANTIZER", Component.OPTIMIZE_TFLITE_MODEL)
# Do calibration.
CALIBRATE = SubComponentItem("CALIBRATE", Component.OPTIMIZE_TFLITE_MODEL)
# Do quantization by MLIR.
QUANTIZE = SubComponentItem("QUANTIZE", Component.OPTIMIZE_TFLITE_MODEL)
# Do sparsification by MLIR.
SPARSIFY = SubComponentItem("SPARSIFY", Component.OPTIMIZE_TFLITE_MODEL)
class ConverterError(Exception):
"""Raised when an error occurs during model conversion."""
def __init__(self, message):
super(ConverterError, self).__init__(message)
self.errors = []
self._parse_error_message(message)
def append_error(self,
error_data: converter_error_data_pb2.ConverterErrorData):
self.errors.append(error_data)
def _parse_error_message(self, message):
"""If the message matches a pattern, assigns the associated error code.
It is difficult to assign an error code to some errrors in MLIR side, Ex:
errors thrown by other components than TFLite or not using mlir::emitError.
This function try to detect them by the error message and assign the
corresponding error code.
Args:
message: The error message of this exception.
"""
error_code_mapping = {
"Failed to functionalize Control Flow V1 ops. Consider using Control "
"Flow V2 ops instead. See https://www.tensorflow.org/api_docs/python/"
"tf/compat/v1/enable_control_flow_v2.":
converter_error_data_pb2.ConverterErrorData
.ERROR_UNSUPPORTED_CONTROL_FLOW_V1,
}
for pattern, error_code in error_code_mapping.items():
if pattern in message:
error_data = converter_error_data_pb2.ConverterErrorData()
error_data.error_message = message
error_data.error_code = error_code
self.append_error(error_data)
return
def convert_phase(component, subcomponent=SubComponent.UNSPECIFIED):
"""The decorator to identify converter component and subcomponent.
Args:
component: Converter component name.
subcomponent: Converter subcomponent name.
Returns:
Forward the result from the wrapped function.
Raises:
ValueError: if component and subcomponent name is not valid.
"""
if component not in Component:
raise ValueError("Given component name not found")
if subcomponent not in SubComponent:
raise ValueError("Given subcomponent name not found")
if (subcomponent != SubComponent.UNSPECIFIED and
subcomponent.component != component):
raise ValueError("component and subcomponent name don't match")
def report_error(error_data: converter_error_data_pb2.ConverterErrorData):
# Always overwrites the component information, but only overwrites the
# subcomponent if it is not available.
error_data.component = component.value
if not error_data.subcomponent:
error_data.subcomponent = subcomponent.name
tflite_metrics = metrics.TFLiteConverterMetrics()
tflite_metrics.set_converter_error(error_data)
def report_error_message(error_message: Text):
error_data = converter_error_data_pb2.ConverterErrorData()
error_data.error_message = error_message
report_error(error_data)
def actual_decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
try:
return func(*args, **kwargs)
except ConverterError as converter_error:
if converter_error.errors:
for error_data in converter_error.errors:
report_error(error_data)
else:
report_error_message(str(converter_error))
raise converter_error from None # Re-throws the exception.
except Exception as error:
report_error_message(str(error))
raise error from None # Re-throws the exception.
return wrapper
return actual_decorator
@@ -0,0 +1,186 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to convert SavedModel to frozen GraphDefs."""
from tensorflow.lite.python import util
from tensorflow.lite.python.convert_phase import Component
from tensorflow.lite.python.convert_phase import convert_phase
from tensorflow.lite.python.convert_phase import SubComponent
from tensorflow.python.client import session
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.saved_model import constants
from tensorflow.python.saved_model import loader
def get_meta_graph_def(saved_model_dir, tag_set):
"""Validate saved_model and extract MetaGraphDef.
Args:
saved_model_dir: saved_model path to convert.
tag_set: Set of tag(s) of the MetaGraphDef to load.
Returns:
The meta_graph_def used for tflite conversion.
Raises:
ValueError: No valid MetaGraphDef for given tag_set.
"""
with session.Session(graph=ops.Graph()) as sess:
return loader.load(sess, tag_set, saved_model_dir)
def get_signature_def(meta_graph, signature_key):
"""Get the signature def from meta_graph with given signature_key.
Args:
meta_graph: meta_graph_def.
signature_key: signature_def in the meta_graph_def.
Returns:
The signature_def used for tflite conversion.
Raises:
ValueError: Given signature_key is not valid for this meta_graph.
"""
signature_def_map = meta_graph.signature_def
signature_def_keys = set(signature_def_map.keys())
logging.info(
"The given SavedModel MetaGraphDef contains SignatureDefs with the "
"following keys: %s", signature_def_keys)
if signature_key not in signature_def_keys:
raise ValueError("No '{}' in the SavedModel\'s SignatureDefs. Possible "
"values are '{}'.".format(signature_key,
",".join(signature_def_keys)))
return signature_def_map[signature_key]
def get_inputs_outputs(signature_def):
"""Get inputs and outputs from SignatureDef.
Args:
signature_def: SignatureDef in the meta_graph_def for conversion.
Returns:
The inputs and outputs in the graph for conversion.
"""
inputs_tensor_info = signature_def.inputs
outputs_tensor_info = signature_def.outputs
def gather_names(tensor_info):
return [tensor_info[key].name for key in tensor_info]
inputs = gather_names(inputs_tensor_info)
outputs = gather_names(outputs_tensor_info)
return inputs, outputs
def _get_tensors(graph, signature_def_tensor_names=None,
user_tensor_names=None):
"""Gets the tensors associated with the tensor names.
Either signature_def_tensor_names or user_tensor_names should be provided. If
the user provides tensors, the tensors associated with the user provided
tensor names are provided. Otherwise, the tensors associated with the names in
the SignatureDef are provided.
Args:
graph: GraphDef representing graph.
signature_def_tensor_names: Tensor names stored in either the inputs or
outputs of a SignatureDef. (default None)
user_tensor_names: Tensor names provided by the user. (default None)
Returns:
List of tensors.
Raises:
ValueError:
signature_def_tensors and user_tensor_names are undefined or empty.
user_tensor_names are not valid.
"""
tensors = []
if user_tensor_names:
# Sort the tensor names.
user_tensor_names = sorted(user_tensor_names)
tensors = util.get_tensors_from_tensor_names(graph, user_tensor_names)
elif signature_def_tensor_names:
tensors = [
graph.get_tensor_by_name(name)
for name in sorted(signature_def_tensor_names)
]
else:
# Throw ValueError if signature_def_tensors and user_tensor_names are both
# either undefined or empty.
raise ValueError(
"Specify either signature_def_tensor_names or user_tensor_names")
return tensors
@convert_phase(Component.PREPARE_TF_MODEL, SubComponent.FREEZE_SAVED_MODEL)
def freeze_saved_model(saved_model_dir, input_arrays, input_shapes,
output_arrays, tag_set, signature_key):
"""Converts a SavedModel to a frozen graph.
Args:
saved_model_dir: SavedModel directory to convert.
input_arrays: List of input tensors to freeze graph with. Uses input arrays
from SignatureDef when none are provided.
input_shapes: Dict of strings representing input tensor names to list of
integers representing input shapes (e.g., {"foo": : [1, 16, 16, 3]}).
Automatically determined when input shapes is None (e.g., {"foo" : None}).
output_arrays: List of output tensors to freeze graph with. Uses output
arrays from SignatureDef when none are provided.
tag_set: Set of tags identifying the MetaGraphDef within the SavedModel to
analyze. All tags in the tag set must be present.
signature_key: Key identifying SignatureDef containing inputs and outputs.
Returns:
frozen_graph_def: Frozen GraphDef.
in_tensors: List of input tensors for the graph.
out_tensors: List of output tensors for the graph.
graph: `Graph` object.
Raises:
ValueError:
SavedModel doesn't contain a MetaGraphDef identified by tag_set.
signature_key is not in the MetaGraphDef.
assets/ directory is in the MetaGraphDef.
input_shapes does not match the length of input_arrays.
input_arrays or output_arrays are not valid.
"""
# Read SignatureDef.
meta_graph = get_meta_graph_def(saved_model_dir, tag_set)
signature_def = get_signature_def(meta_graph, signature_key)
inputs, outputs = get_inputs_outputs(signature_def)
# Check SavedModel for assets directory.
collection_def = meta_graph.collection_def
if constants.ASSETS_KEY in collection_def:
raise ValueError("SavedModels with assets/ directory are not supported.")
graph = ops.Graph()
with session.Session(graph=graph) as sess:
loader.load(sess, meta_graph.meta_info_def.tags, saved_model_dir)
# Gets input and output tensors.
# TODO(zhixianyan): Use TFLite supported Op list to filter outputs.
in_tensors = _get_tensors(graph, inputs, input_arrays)
out_tensors = _get_tensors(graph, outputs, output_arrays)
util.set_tensor_shapes(in_tensors, input_shapes)
frozen_graph_def = util.freeze_graph(sess, in_tensors, out_tensors)
return frozen_graph_def, in_tensors, out_tensors, sess.graph
@@ -0,0 +1,47 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
class InterpreterWrapper:
def __init__(self, *args, **kwargs) -> None: ...
def AllocateTensors(self, subgraph_index: int = ...) -> object: ...
def GetSignatureDefs(self) -> object: ...
def GetSubgraphIndexFromSignature(self, arg0: str) -> object: ...
def GetTensor(self, tensor_index: int, subgraph_index: int = ...) -> object: ...
def InputIndices(self) -> object: ...
def Invoke(self, subgraph_index: int = ...) -> object: ...
def ModifyGraphWithDelegate(self, arg0: int) -> object: ...
def NodeInputs(self, arg0: int) -> object: ...
def NodeName(self, arg0: int) -> str: ...
def NodeOutputs(self, arg0: int) -> object: ...
def NumNodes(self) -> int: ...
def NumSubgraphs(self) -> int: ...
def NumTensors(self, arg0: int) -> int: ...
def OutputIndices(self) -> object: ...
def ResetVariableTensors(self) -> object: ...
def ResizeInputTensor(self, i: int, value: object, strict: bool, subgraph_index: int = ...) -> object: ...
def SetNumThreads(self, arg0: int) -> object: ...
def SetTensor(self, i: int, value: object, subgraph_index: int = ...) -> object: ...
def TensorName(self, arg0: int, arg1: int) -> str: ...
def TensorQuantization(self, arg0: int, arg1: int) -> object: ...
def TensorQuantizationParameters(self, arg0: int, arg1: int) -> object: ...
def TensorSize(self, arg0: int, arg1: int) -> object: ...
def TensorSizeSignature(self, arg0: int, arg1: int) -> object: ...
def TensorSparsityParameters(self, arg0: int, arg1: int) -> object: ...
def TensorType(self, arg0: int, arg1: int) -> object: ...
def interpreter(self) -> int: ...
def tensor(self, base_object: object, tensor_index: int, subgraph_index: int = ...) -> object: ...
def CreateWrapperFromBuffer(*args, **kwargs): ...
def CreateWrapperFromFile(*args, **kwargs): ...
@@ -0,0 +1,83 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Constants for TFLite."""
from tensorflow.compiler.mlir.lite import converter_flags_pb2 as _converter_flags_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.util.all_util import remove_undocumented
from tensorflow.python.util.tf_export import tf_export as _tf_export
FLOAT = dtypes.float32
FLOAT16 = dtypes.float16
INT32 = dtypes.int32
INT64 = dtypes.int64
STRING = dtypes.string
QUANTIZED_UINT8 = dtypes.uint8
INT8 = dtypes.int8
INT16 = dtypes.int16
COMPLEX64 = dtypes.complex64
TENSORFLOW_GRAPHDEF = _converter_flags_pb2.TENSORFLOW_GRAPHDEF
TFLITE = _converter_flags_pb2.TFLITE
GRAPHVIZ_DOT = _converter_flags_pb2.GRAPHVIZ_DOT
UNSET = _converter_flags_pb2.ConverterFlags.ModelOriginFramework.Name(
_converter_flags_pb2.ConverterFlags.UNSET
)
TENSORFLOW = _converter_flags_pb2.ConverterFlags.ModelOriginFramework.Name(
_converter_flags_pb2.ConverterFlags.TENSORFLOW
)
KERAS = _converter_flags_pb2.ConverterFlags.ModelOriginFramework.Name(
_converter_flags_pb2.ConverterFlags.KERAS
)
JAX = _converter_flags_pb2.ConverterFlags.ModelOriginFramework.Name(
_converter_flags_pb2.ConverterFlags.JAX
)
PYTORCH = _converter_flags_pb2.ConverterFlags.ModelOriginFramework.Name(
_converter_flags_pb2.ConverterFlags.PYTORCH
)
_tf_export(v1=["lite.constants.FLOAT"]).export_constant(__name__, "FLOAT")
_tf_export(v1=["lite.constants.FLOAT16"]).export_constant(__name__, "FLOAT16")
_tf_export(v1=["lite.constants.INT32"]).export_constant(__name__, "INT32")
_tf_export(v1=["lite.constants.INT64"]).export_constant(__name__, "INT64")
_tf_export(v1=["lite.constants.STRING"]).export_constant(__name__, "STRING")
_tf_export(v1=["lite.constants.QUANTIZED_UINT8"]).export_constant(
__name__, "QUANTIZED_UINT8")
_tf_export(v1=["lite.constants.INT8"]).export_constant(__name__, "INT8")
_tf_export(v1=["lite.constants.INT16"]).export_constant(__name__, "INT16")
_tf_export(v1=["lite.constants.TFLITE"]).export_constant(__name__, "TFLITE")
_tf_export(v1=["lite.constants.GRAPHVIZ_DOT"]).export_constant(
__name__, "GRAPHVIZ_DOT")
_allowed_symbols = [
"FLOAT",
"FLOAT16",
"INT32",
"INT64",
"STRING",
"QUANTIZED_UINT8",
"INT8",
"INT16",
"COMPLEX64",
"TENSORFLOW_GRAPHDEF",
"TFLITE",
"GRAPHVIZ_DOT",
"UNSET",
"TENSORFLOW",
"KERAS",
"JAX",
"PYTORCH",
]
remove_undocumented(__name__, _allowed_symbols)
@@ -0,0 +1,18 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
class MetricsWrapper:
def __init__(self, arg0: str) -> None: ...
def ExportMetrics(self) -> object: ...
@@ -0,0 +1,70 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python TFLite metrics helper."""
import os
from typing import Optional, Text
# pylint: disable=g-import-not-at-top
if not os.path.splitext(__file__)[0].endswith(
os.path.join('tflite_runtime', 'metrics_portable')):
# This file is part of tensorflow package.
from tensorflow.lite.python.metrics import metrics_interface # type: ignore
else:
# This file is part of tflite_runtime package.
from tflite_runtime import metrics_interface # type: ignore
# pylint: enable=g-import-not-at-top
class TFLiteMetrics(metrics_interface.TFLiteMetricsInterface):
"""TFLite metrics helper."""
def __init__(self,
model_hash: Optional[Text] = None,
model_path: Optional[Text] = None) -> None:
pass
def increase_counter_debugger_creation(self):
pass
def increase_counter_interpreter_creation(self):
pass
def increase_counter_converter_attempt(self):
pass
def increase_counter_converter_success(self):
pass
def set_converter_param(self, name, value):
pass
def set_converter_error(self, error_data):
pass
def set_converter_latency(self, value):
pass
class TFLiteConverterMetrics(TFLiteMetrics):
"""Similar to TFLiteMetrics but specialized for converter."""
def __del__(self):
pass
def set_export_required(self):
pass
def export_metrics(self):
pass
@@ -0,0 +1,48 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python TFLite metrics helper interface."""
import abc
class TFLiteMetricsInterface(metaclass=abc.ABCMeta):
"""Abstract class for TFLiteMetrics."""
@abc.abstractmethod
def increase_counter_debugger_creation(self):
raise NotImplementedError
@abc.abstractmethod
def increase_counter_interpreter_creation(self):
raise NotImplementedError
@abc.abstractmethod
def increase_counter_converter_attempt(self):
raise NotImplementedError
@abc.abstractmethod
def increase_counter_converter_success(self):
raise NotImplementedError
@abc.abstractmethod
def set_converter_param(self, name, value):
raise NotImplementedError
@abc.abstractmethod
def set_converter_error(self, error_data):
raise NotImplementedError
@abc.abstractmethod
def set_converter_latency(self, value):
raise NotImplementedError
@@ -0,0 +1,34 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Stub to make pywrap metrics wrapper accessible."""
from tensorflow.compiler.mlir.lite.metrics import converter_error_data_pb2
from tensorflow.compiler.mlir.lite.python import wrap_converter
from tensorflow.lite.python.metrics._pywrap_tensorflow_lite_metrics_wrapper import MetricsWrapper # pylint: disable=unused-import
def retrieve_collected_errors():
"""Returns and clears the list of collected errors in ErrorCollector.
The RetrieveCollectedErrors function in C++ returns a list of serialized proto
messages. This function will convert them to ConverterErrorData instances.
Returns:
A list of ConverterErrorData.
"""
serialized_message_list = wrap_converter.wrapped_retrieve_collected_errors()
return list(
map(converter_error_data_pb2.ConverterErrorData.FromString,
serialized_message_list))
@@ -0,0 +1,38 @@
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Callable, overload
class CalibrationWrapper:
def __init__(self, arg0: object, arg1: list[str], arg2: list[Callable[[int], None]]) -> None: ...
def Calibrate(self) -> object: ...
@overload
def FeedTensor(self, arg0: object, arg1: str) -> object: ...
@overload
def FeedTensor(self, arg0: object) -> object: ...
@overload
def Prepare(self, arg0: object, arg1: str) -> object: ...
@overload
def Prepare(self, arg0: object) -> object: ...
@overload
def Prepare(self, arg0: str) -> object: ...
@overload
def Prepare(self) -> object: ...
@overload
def QuantizeModel(self, arg0: int, arg1: int, arg2: bool, arg3: int, arg4: int, arg5: bool, arg6: bool) -> object: ...
@overload
def QuantizeModel(self, arg0: int, arg1: int, arg2: bool, arg3: str) -> object: ...
def AddIntermediateTensors(arg0: object) -> object: ...
@@ -0,0 +1,259 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrapper for post training quantization with calibration."""
import numpy as np
from tensorflow.lite.python.convert_phase import Component
from tensorflow.lite.python.convert_phase import convert_phase
from tensorflow.lite.python.convert_phase import SubComponent
from tensorflow.lite.python.interpreter import Interpreter
from tensorflow.python.framework import dtypes
from tensorflow.python.util.lazy_loader import LazyLoader
# Lazy load since some of the performance benchmark skylark rules
# break dependencies. Must use double quotes to match code internal rewrite
# rule.
_calibration_wrapper = LazyLoader(
"_calibration_wrapper",
globals(),
(
"tensorflow.lite.python.optimize."
"_pywrap_tensorflow_lite_calibration_wrapper"
),
)
def add_intermediate_tensors(model_content):
"""Adds intermediate tensors to fused op if needed."""
return _calibration_wrapper.AddIntermediateTensors(model_content)
class Calibrator:
"""Calibrates a floating point model and then quantizes it.
This is an internal class, not a public interface.
"""
def __init__(
self,
model_content,
custom_op_registerers_by_name=None,
custom_op_registerers_by_func=None,
):
"""Constructor.
Args:
model_content: Content of a TF-Lite Flatbuffer file.
custom_op_registerers_by_name: List of str (symbol names) that take a
pointer to a MutableOpResolver and register custom ops.
custom_op_registerers_by_func: List of functions that take a pointer to a
MutableOpResolver and register custom ops.
Raises:
ValueError: If the calibrator was unable to open the model.
"""
if not model_content:
raise ValueError("`model_content` must be specified.")
if custom_op_registerers_by_name is None:
custom_op_registerers_by_name = []
if custom_op_registerers_by_func is None:
custom_op_registerers_by_func = []
try:
self._calibrator = _calibration_wrapper.CalibrationWrapper(
model_content,
custom_op_registerers_by_name,
custom_op_registerers_by_func,
)
self._model_content = model_content
except Exception as e:
raise ValueError("Failed to parse the model: %s." % e)
if not self._calibrator:
raise ValueError("Failed to parse the model.")
self._interpreter = None
def _create_input_array_from_dict(self, signature_key, inputs):
input_array = []
signature_runner = self._interpreter.get_signature_runner(signature_key)
input_details = sorted(
signature_runner.get_input_details().items(),
key=lambda item: item[1]["index"],
)
for input_name, _ in input_details:
input_array.append(inputs[input_name])
return input_array
def _feed_tensors(self, dataset_gen, resize_input):
"""Feed tensors to the calibrator."""
initialized = {}
for sample in dataset_gen():
if isinstance(sample, tuple):
if not isinstance(sample[1], dict):
raise ValueError(
"You need to provide either a dictionary with input "
"names and values in the second argument in the "
"tuple"
)
# Convert signature based inputs to the tensor index based data.
if self._interpreter is None:
self._interpreter = Interpreter(model_content=self._model_content)
signature_key = sample[0]
input_array = self._create_input_array_from_dict(
signature_key, sample[1]
)
elif isinstance(sample, dict):
# Convert signature based inputs to the tensor index based data.
if self._interpreter is None:
self._interpreter = Interpreter(model_content=self._model_content)
signature_key = None
input_array = self._create_input_array_from_dict(None, sample)
elif isinstance(sample, list):
signature_key = None
input_array = sample
else:
raise ValueError(
"You need to provide either a dictionary with input "
"names and values, a tuple with signature key and a "
"dictionary with input names and values, or an array "
"with input values in the order of input tensors of "
"the graph in the representative_dataset function. "
"Unsupported value from dataset: {}.".format(sample)
)
if signature_key not in initialized:
initialized[signature_key] = True
if resize_input:
if signature_key is not None:
self._calibrator.Prepare(
[list(s.shape) for s in input_array], signature_key
)
else:
self._calibrator.Prepare([list(s.shape) for s in input_array])
else:
if signature_key is not None:
self._calibrator.Prepare(signature_key)
else:
self._calibrator.Prepare()
if signature_key is not None:
self._calibrator.FeedTensor(input_array, signature_key)
else:
self._calibrator.FeedTensor(input_array)
@convert_phase(
Component.OPTIMIZE_TFLITE_MODEL,
SubComponent.QUANTIZE_USING_DEPRECATED_QUANTIZER,
)
def calibrate_and_quantize(
self,
dataset_gen,
input_type,
output_type,
allow_float,
activations_type=dtypes.int8,
bias_type=dtypes.int32,
resize_input=True,
disable_per_channel=False,
disable_per_channel_quantization_for_dense_layers=False,
):
"""Calibrates the model with specified generator and then quantizes it.
The input shapes of the calibrator are resized with the calibration data if
`resize_input` is set.
Returns:
A quantized model.
Args:
dataset_gen: A generator that generates calibration samples.
input_type: A tf.dtype representing the desired real-value input type.
output_type: A tf.dtype representing the desired real-value output type.
allow_float: A boolean. False if the resulting model cannot perform float
computation, useful when targeting an integer-only backend. If False, an
error will be thrown if an operation cannot be quantized, otherwise the
model will fallback to float ops.
activations_type: A tf.dtype representing the desired type for
activations.
bias_type: A tf.dtype representing the desired type for bias.
resize_input: A boolean. True if the shape of the sample data is different
from the input.
disable_per_channel: A boolean. True if disabling per-channel
quantization.
disable_per_channel_quantization_for_dense_layers: A boolean. True if
disabling per-channel quantization only in Dense layers.
"""
self._feed_tensors(dataset_gen, resize_input)
return self._calibrator.QuantizeModel(
np.dtype(input_type.as_numpy_dtype()).num,
np.dtype(output_type.as_numpy_dtype()).num,
allow_float,
np.dtype(activations_type.as_numpy_dtype()).num,
np.dtype(bias_type.as_numpy_dtype()).num,
disable_per_channel,
disable_per_channel_quantization_for_dense_layers,
)
@convert_phase(
Component.OPTIMIZE_TFLITE_MODEL,
SubComponent.QUANTIZE_USING_DEPRECATED_QUANTIZER,
)
def calibrate_and_quantize_single(
self,
dataset_gen,
input_type,
output_type,
allow_float,
op_output_name,
resize_input=True,
):
"""Calibrates the model with specified generator and then quantizes it.
Only the single op with output op_output_name will be quantized.
The input shapes of the calibrator are resized with the calibration data.
Returns:
A quantized model.
Args:
dataset_gen: A generator that generates calibration samples.
input_type: A tf.dtype representing the desired real-value input type.
output_type: A tf.dtype representing the desired real-value output type.
allow_float: A boolean. False if the resulting model cannot perform float
computation, useful when targeting an integer-only backend. If False, an
error will be thrown if an operation cannot be quantized, otherwise the
model will fallback to float ops.
op_output_name: A string, only this op will be quantized.
resize_input: A boolean. True if the shape of the sample data is different
from the input.
"""
self._feed_tensors(dataset_gen, resize_input)
return self._calibrator.QuantizeModel(
np.dtype(input_type.as_numpy_dtype()).num,
np.dtype(output_type.as_numpy_dtype()).num,
allow_float,
op_output_name,
)
@convert_phase(Component.OPTIMIZE_TFLITE_MODEL, SubComponent.CALIBRATE)
def calibrate(self, dataset_gen):
"""Calibrates the model with specified generator.
Returns:
A model with min and max calibration stats.
Args:
dataset_gen: A generator that generates calibration samples.
"""
self._feed_tensors(dataset_gen, resize_input=True)
return self._calibrator.Calibrate()
@@ -0,0 +1,45 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Schema utilities to get builtin code from operator code."""
from tensorflow.python.util import all_util
def get_builtin_code_from_operator_code(opcode):
"""Return the builtin code of the given operator code.
The following method is introduced to resolve op builtin code shortage
problem. The new builtin operator will be assigned to the extended builtin
code field in the flatbuffer schema. Those methods helps to hide builtin code
details.
Args:
opcode: Operator code.
Returns:
The builtin code of the given operator code.
"""
# Access BuiltinCode() method first if available.
if hasattr(opcode, 'BuiltinCode') and callable(opcode.BuiltinCode):
return max(opcode.BuiltinCode(), opcode.DeprecatedBuiltinCode())
return max(opcode.builtinCode, opcode.deprecatedBuiltinCode)
_allowed_symbols = [
'get_builtin_code_from_operator_code',
]
all_util.remove_undocumented(__name__, _allowed_symbols)
@@ -0,0 +1,696 @@
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python command line interface for converting TF models to TFLite models."""
import argparse
import os
import sys
import warnings
from absl import app
import tensorflow as tf
from tensorflow.lite.python import lite
from tensorflow.lite.python.convert import register_custom_opdefs
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
from tensorflow.lite.toco.logging import gen_html
from tensorflow.python import tf2
from tensorflow.python.framework import dtypes
from tensorflow.python.platform import gfile
from tensorflow.python.util import keras_deps
# Needed to enable TF2 by default.
_ = tf.keras.models.save_model # ensure necessary imports are executed
def _parse_array(values, type_fn=str):
if values is not None:
return [type_fn(val) for val in values.split(",") if val]
return None
def _parse_set(values):
if values is not None:
return set([item for item in values.split(",") if item])
return None
def _parse_inference_type(value, flag):
"""Converts the inference type to the value of the constant.
Args:
value: str representing the inference type.
flag: str representing the flag name.
Returns:
tf.dtype.
Raises:
ValueError: Unsupported value.
"""
if value == "FLOAT":
return dtypes.float32
if value == "INT8":
return dtypes.int8
if value == "UINT8" or value == "QUANTIZED_UINT8":
return dtypes.uint8
raise ValueError(
"Unsupported value for `{}` flag. Expected FLOAT, INT8, UINT8, or "
"QUANTIZED_UINT8 instead got {}.".format(flag, value))
class _ParseBooleanFlag(argparse.Action):
"""Helper class to parse boolean flag that optionally accepts truth value."""
def __init__(self, option_strings, dest, nargs=None, **kwargs):
if nargs != "?":
# This should never happen. This class is only used once below with
# nargs="?".
raise ValueError(
"This parser only supports nargs='?' (0 or 1 additional arguments)")
super(_ParseBooleanFlag, self).__init__(
option_strings, dest, nargs=nargs, **kwargs)
def __call__(self, parser, namespace, values, option_string=None):
if values is None:
# Handling `--boolean_flag`.
# Without additional arguments, it implies true.
flag_value = True
elif values.lower() == "true":
# Handling `--boolean_flag=true`.
# (Case insensitive after the equal sign)
flag_value = True
elif values.lower() == "false":
# Handling `--boolean_flag=false`.
# (Case insensitive after the equal sign)
flag_value = False
else:
raise ValueError("Invalid argument to --{}. Must use flag alone,"
" or specify true/false.".format(self.dest))
setattr(namespace, self.dest, flag_value)
def _get_tflite_converter(flags):
"""Makes a TFLiteConverter object based on the flags provided.
Args:
flags: argparse.Namespace object containing TFLite flags.
Returns:
TFLiteConverter object.
Raises:
ValueError: Invalid flags.
"""
# Parse input and output arrays.
input_arrays = _parse_array(flags.input_arrays)
input_shapes = None
if flags.input_shapes:
input_shapes_list = [
_parse_array(shape, type_fn=int)
for shape in flags.input_shapes.split(":")
]
input_shapes = dict(list(zip(input_arrays, input_shapes_list)))
output_arrays = _parse_array(flags.output_arrays)
converter_kwargs = {
"input_arrays": input_arrays,
"input_shapes": input_shapes,
"output_arrays": output_arrays
}
# Create TFLiteConverter.
if flags.graph_def_file:
converter_fn = lite.TFLiteConverter.from_frozen_graph
converter_kwargs["graph_def_file"] = flags.graph_def_file
elif flags.saved_model_dir:
converter_fn = lite.TFLiteConverter.from_saved_model
converter_kwargs["saved_model_dir"] = flags.saved_model_dir
converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set)
converter_kwargs["signature_key"] = flags.saved_model_signature_key
elif flags.keras_model_file:
converter_fn = lite.TFLiteConverter.from_keras_model_file
converter_kwargs["model_file"] = flags.keras_model_file
else:
raise ValueError("--graph_def_file, --saved_model_dir, or "
"--keras_model_file must be specified.")
return converter_fn(**converter_kwargs)
def _convert_tf1_model(flags):
"""Calls function to convert the TensorFlow 1.X model into a TFLite model.
Args:
flags: argparse.Namespace object.
Raises:
ValueError: Invalid flags.
"""
# Register custom opdefs before converter object creation.
if flags.custom_opdefs:
register_custom_opdefs(_parse_array(flags.custom_opdefs))
# Create converter.
converter = _get_tflite_converter(flags)
if flags.inference_type:
converter.inference_type = _parse_inference_type(flags.inference_type,
"inference_type")
if flags.inference_input_type:
converter.inference_input_type = _parse_inference_type(
flags.inference_input_type, "inference_input_type")
if flags.output_format:
converter.output_format = _toco_flags_pb2.FileFormat.Value(
flags.output_format)
if flags.mean_values and flags.std_dev_values:
input_arrays = converter.get_input_arrays()
std_dev_values = _parse_array(flags.std_dev_values, type_fn=float)
# In quantized inference, mean_value has to be integer so that the real
# value 0.0 is exactly representable.
if converter.inference_type == dtypes.float32:
mean_values = _parse_array(flags.mean_values, type_fn=float)
else:
mean_values = _parse_array(flags.mean_values, type_fn=int)
quant_stats = list(zip(mean_values, std_dev_values))
if ((not flags.input_arrays and len(input_arrays) > 1) or
(len(input_arrays) != len(quant_stats))):
raise ValueError("Mismatching --input_arrays, --std_dev_values, and "
"--mean_values. The flags must have the same number of "
"items. The current input arrays are '{0}'. "
"--input_arrays must be present when specifying "
"--std_dev_values and --mean_values with multiple input "
"tensors in order to map between names and "
"values.".format(",".join(input_arrays)))
converter.quantized_input_stats = dict(list(zip(input_arrays, quant_stats)))
if (flags.default_ranges_min is not None) and (flags.default_ranges_max is
not None):
converter.default_ranges_stats = (flags.default_ranges_min,
flags.default_ranges_max)
if flags.drop_control_dependency:
converter.drop_control_dependency = flags.drop_control_dependency
if flags.reorder_across_fake_quant:
converter.reorder_across_fake_quant = flags.reorder_across_fake_quant
if flags.change_concat_input_ranges:
converter.change_concat_input_ranges = (
flags.change_concat_input_ranges == "TRUE")
if flags.allow_custom_ops:
converter.allow_custom_ops = flags.allow_custom_ops
if flags.target_ops:
ops_set_options = lite.OpsSet.get_options()
converter.target_spec.supported_ops = set()
for option in flags.target_ops.split(","):
if option not in ops_set_options:
raise ValueError("Invalid value for --target_ops. Options: "
"{0}".format(",".join(ops_set_options)))
converter.target_spec.supported_ops.add(lite.OpsSet(option))
if flags.experimental_select_user_tf_ops:
if lite.OpsSet.SELECT_TF_OPS not in converter.target_spec.supported_ops:
raise ValueError("--experimental_select_user_tf_ops can only be set if "
"--target_ops contains SELECT_TF_OPS.")
user_op_set = set()
for op_name in flags.experimental_select_user_tf_ops.split(","):
user_op_set.add(op_name)
converter.target_spec.experimental_select_user_tf_ops = list(user_op_set)
if flags.post_training_quantize:
converter.optimizations = [lite.Optimize.DEFAULT]
if converter.inference_type != dtypes.float32:
print("--post_training_quantize quantizes a graph of inference_type "
"FLOAT. Overriding inference_type to FLOAT.")
converter.inference_type = dtypes.float32
if flags.quantize_to_float16:
converter.target_spec.supported_types = [dtypes.float16]
if not flags.post_training_quantize:
print("--quantize_to_float16 will only take effect with the "
"--post_training_quantize flag enabled.")
if flags.dump_graphviz_dir:
converter.dump_graphviz_dir = flags.dump_graphviz_dir
if flags.dump_graphviz_video:
converter.dump_graphviz_vode = flags.dump_graphviz_video
if flags.conversion_summary_dir:
converter.conversion_summary_dir = flags.conversion_summary_dir
converter.experimental_new_converter = flags.experimental_new_converter
if flags.experimental_new_quantizer is not None:
converter.experimental_new_quantizer = flags.experimental_new_quantizer
# Convert model.
output_data = converter.convert()
with gfile.GFile(flags.output_file, "wb") as f:
f.write(output_data)
def _convert_tf2_model(flags):
"""Calls function to convert the TensorFlow 2.0 model into a TFLite model.
Args:
flags: argparse.Namespace object.
Raises:
ValueError: Unsupported file format.
"""
# Load the model.
if flags.saved_model_dir:
converter = lite.TFLiteConverterV2.from_saved_model(
flags.saved_model_dir,
signature_keys=_parse_array(flags.saved_model_signature_key),
tags=_parse_set(flags.saved_model_tag_set))
elif flags.keras_model_file:
model = keras_deps.get_load_model_function()(flags.keras_model_file)
converter = lite.TFLiteConverterV2.from_keras_model(model)
converter.experimental_new_converter = flags.experimental_new_converter
if flags.experimental_new_quantizer is not None:
converter.experimental_new_quantizer = flags.experimental_new_quantizer
# Convert the model.
tflite_model = converter.convert()
with gfile.GFile(flags.output_file, "wb") as f:
f.write(tflite_model)
def _check_tf1_flags(flags, unparsed):
"""Checks the parsed and unparsed flags to ensure they are valid in 1.X.
Raises an error if previously support unparsed flags are found. Raises an
error for parsed flags that don't meet the required conditions.
Args:
flags: argparse.Namespace object containing TFLite flags.
unparsed: List of unparsed flags.
Raises:
ValueError: Invalid flags.
"""
# Check unparsed flags for common mistakes based on previous TOCO.
def _get_message_unparsed(flag, orig_flag, new_flag):
if flag.startswith(orig_flag):
return "\n Use {0} instead of {1}".format(new_flag, orig_flag)
return ""
if unparsed:
output = ""
for flag in unparsed:
output += _get_message_unparsed(flag, "--input_file", "--graph_def_file")
output += _get_message_unparsed(flag, "--savedmodel_directory",
"--saved_model_dir")
output += _get_message_unparsed(flag, "--std_value", "--std_dev_values")
output += _get_message_unparsed(flag, "--batch_size", "--input_shapes")
output += _get_message_unparsed(flag, "--dump_graphviz",
"--dump_graphviz_dir")
if output:
raise ValueError(output)
# Check that flags are valid.
if flags.graph_def_file and (not flags.input_arrays or
not flags.output_arrays):
raise ValueError("--input_arrays and --output_arrays are required with "
"--graph_def_file")
if flags.input_shapes:
if not flags.input_arrays:
raise ValueError("--input_shapes must be used with --input_arrays")
if flags.input_shapes.count(":") != flags.input_arrays.count(","):
raise ValueError("--input_shapes and --input_arrays must have the same "
"number of items")
if flags.std_dev_values or flags.mean_values:
if bool(flags.std_dev_values) != bool(flags.mean_values):
raise ValueError("--std_dev_values and --mean_values must be used "
"together")
if flags.std_dev_values.count(",") != flags.mean_values.count(","):
raise ValueError("--std_dev_values, --mean_values must have the same "
"number of items")
if (flags.default_ranges_min is None) != (flags.default_ranges_max is None):
raise ValueError("--default_ranges_min and --default_ranges_max must be "
"used together")
if flags.dump_graphviz_video and not flags.dump_graphviz_dir:
raise ValueError("--dump_graphviz_video must be used with "
"--dump_graphviz_dir")
if flags.custom_opdefs and not flags.experimental_new_converter:
raise ValueError("--custom_opdefs must be used with "
"--experimental_new_converter")
if flags.custom_opdefs and not flags.allow_custom_ops:
raise ValueError("--custom_opdefs must be used with --allow_custom_ops")
if (flags.experimental_select_user_tf_ops and
not flags.experimental_new_converter):
raise ValueError("--experimental_select_user_tf_ops must be used with "
"--experimental_new_converter")
def _check_tf2_flags(flags):
"""Checks the parsed and unparsed flags to ensure they are valid in 2.X.
Args:
flags: argparse.Namespace object containing TFLite flags.
Raises:
ValueError: Invalid flags.
"""
if not flags.keras_model_file and not flags.saved_model_dir:
raise ValueError("one of the arguments --saved_model_dir "
"--keras_model_file is required")
def _get_tf1_flags(parser):
"""Returns ArgumentParser for tflite_convert for TensorFlow 1.X.
Args:
parser: ArgumentParser
"""
# Input file flags.
input_file_group = parser.add_mutually_exclusive_group(required=True)
input_file_group.add_argument(
"--graph_def_file",
type=str,
help="Full filepath of file containing frozen TensorFlow GraphDef.")
input_file_group.add_argument(
"--saved_model_dir",
type=str,
help="Full filepath of directory containing the SavedModel.")
input_file_group.add_argument(
"--keras_model_file",
type=str,
help="Full filepath of HDF5 file containing tf.Keras model.")
# Model format flags.
parser.add_argument(
"--output_format",
type=str.upper,
choices=["TFLITE", "GRAPHVIZ_DOT"],
help="Output file format.")
parser.add_argument(
"--inference_type",
type=str.upper,
default="FLOAT",
help=("Target data type of real-number arrays in the output file. "
"Must be either FLOAT, INT8 or UINT8."))
parser.add_argument(
"--inference_input_type",
type=str.upper,
help=("Target data type of real-number input arrays. Allows for a "
"different type for input arrays in the case of quantization. "
"Must be either FLOAT, INT8 or UINT8."))
# Input and output arrays flags.
parser.add_argument(
"--input_arrays",
type=str,
help="Names of the input arrays, comma-separated.")
parser.add_argument(
"--input_shapes",
type=str,
help="Shapes corresponding to --input_arrays, colon-separated.")
parser.add_argument(
"--output_arrays",
type=str,
help="Names of the output arrays, comma-separated.")
# SavedModel related flags.
parser.add_argument(
"--saved_model_tag_set",
type=str,
help=("Comma-separated set of tags identifying the MetaGraphDef within "
"the SavedModel to analyze. All tags must be present. In order to "
"pass in an empty tag set, pass in \"\". (default \"serve\")"))
parser.add_argument(
"--saved_model_signature_key",
type=str,
help=("Key identifying the SignatureDef containing inputs and outputs. "
"(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)"))
# Quantization flags.
parser.add_argument(
"--std_dev_values",
type=str,
help=("Standard deviation of training data for each input tensor, "
"comma-separated floats. Used for quantized input tensors. "
"(default None)"))
parser.add_argument(
"--mean_values",
type=str,
help=("Mean of training data for each input tensor, comma-separated "
"floats. Used for quantized input tensors. (default None)"))
parser.add_argument(
"--default_ranges_min",
type=float,
help=("Default value for min bound of min/max range values used for all "
"arrays without a specified range, Intended for experimenting with "
"quantization via \"dummy quantization\". (default None)"))
parser.add_argument(
"--default_ranges_max",
type=float,
help=("Default value for max bound of min/max range values used for all "
"arrays without a specified range, Intended for experimenting with "
"quantization via \"dummy quantization\". (default None)"))
# quantize_weights is DEPRECATED.
parser.add_argument(
"--quantize_weights",
dest="post_training_quantize",
action="store_true",
help=argparse.SUPPRESS)
parser.add_argument(
"--post_training_quantize",
dest="post_training_quantize",
action="store_true",
help=(
"Boolean indicating whether to quantize the weights of the "
"converted float model. Model size will be reduced and there will "
"be latency improvements (at the cost of accuracy). (default False)"))
parser.add_argument(
"--quantize_to_float16",
dest="quantize_to_float16",
action="store_true",
help=("Boolean indicating whether to quantize weights to fp16 instead of "
"the default int8 when post-training quantization "
"(--post_training_quantize) is enabled. (default False)"))
# Graph manipulation flags.
parser.add_argument(
"--drop_control_dependency",
action="store_true",
help=("Boolean indicating whether to drop control dependencies silently. "
"This is due to TensorFlow not supporting control dependencies. "
"(default True)"))
parser.add_argument(
"--reorder_across_fake_quant",
action="store_true",
help=("Boolean indicating whether to reorder FakeQuant nodes in "
"unexpected locations. Used when the location of the FakeQuant "
"nodes is preventing graph transformations necessary to convert "
"the graph. Results in a graph that differs from the quantized "
"training graph, potentially causing differing arithmetic "
"behavior. (default False)"))
# Usage for this flag is --change_concat_input_ranges=true or
# --change_concat_input_ranges=false in order to make it clear what the flag
# is set to. This keeps the usage consistent with other usages of the flag
# where the default is different. The default value here is False.
parser.add_argument(
"--change_concat_input_ranges",
type=str.upper,
choices=["TRUE", "FALSE"],
help=("Boolean to change behavior of min/max ranges for inputs and "
"outputs of the concat operator for quantized models. Changes the "
"ranges of concat operator overlap when true. (default False)"))
# Permitted ops flags.
parser.add_argument(
"--allow_custom_ops",
action=_ParseBooleanFlag,
nargs="?",
help=("Boolean indicating whether to allow custom operations. When false "
"any unknown operation is an error. When true, custom ops are "
"created for any op that is unknown. The developer will need to "
"provide these to the TensorFlow Lite runtime with a custom "
"resolver. (default False)"))
parser.add_argument(
"--custom_opdefs",
type=str,
help=("String representing a list of custom ops OpDefs delineated with "
"commas that are included in the GraphDef. Required when using "
"custom operations with --experimental_new_converter."))
parser.add_argument(
"--target_ops",
type=str,
help=("Experimental flag, subject to change. Set of OpsSet options "
"indicating which converter to use. Options: {0}. One or more "
"option may be specified. (default set([OpsSet.TFLITE_BUILTINS]))"
"".format(",".join(lite.OpsSet.get_options()))))
parser.add_argument(
"--experimental_select_user_tf_ops",
type=str,
help=("Experimental flag, subject to change. Comma separated list of "
"user's defined TensorFlow operators required in the runtime."))
# Logging flags.
parser.add_argument(
"--dump_graphviz_dir",
type=str,
help=("Full filepath of folder to dump the graphs at various stages of "
"processing GraphViz .dot files. Preferred over --output_format="
"GRAPHVIZ_DOT in order to keep the requirements of the output "
"file."))
parser.add_argument(
"--dump_graphviz_video",
action="store_true",
help=("Boolean indicating whether to dump the graph after every graph "
"transformation"))
parser.add_argument(
"--conversion_summary_dir",
type=str,
help=("Full filepath to store the conversion logs, which includes "
"graphviz of the model before/after the conversion, an HTML report "
"and the conversion proto buffers. This will only be generated "
"when passing --experimental_new_converter"))
def _get_tf2_flags(parser):
"""Returns ArgumentParser for tflite_convert for TensorFlow 2.0.
Args:
parser: ArgumentParser
"""
# Input file flags.
input_file_group = parser.add_mutually_exclusive_group()
input_file_group.add_argument(
"--saved_model_dir",
type=str,
help="Full path of the directory containing the SavedModel.")
input_file_group.add_argument(
"--keras_model_file",
type=str,
help="Full filepath of HDF5 file containing tf.Keras model.")
# SavedModel related flags.
parser.add_argument(
"--saved_model_tag_set",
type=str,
help=("Comma-separated set of tags identifying the MetaGraphDef within "
"the SavedModel to analyze. All tags must be present. In order to "
"pass in an empty tag set, pass in \"\". (default \"serve\")"))
parser.add_argument(
"--saved_model_signature_key",
type=str,
help=("Key identifying the SignatureDef containing inputs and outputs. "
"(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)"))
# Enables 1.X converter in 2.X.
parser.add_argument(
"--enable_v1_converter",
action="store_true",
help=("Enables the TensorFlow V1 converter in 2.0"))
def _get_parser(use_v2_converter):
"""Returns an ArgumentParser for tflite_convert.
Args:
use_v2_converter: Indicates which converter to return.
Return: ArgumentParser.
"""
parser = argparse.ArgumentParser(
description=("Command line tool to run TensorFlow Lite Converter."))
# Output file flag.
parser.add_argument(
"--output_file",
type=str,
help="Full filepath of the output file.",
required=True)
if use_v2_converter:
_get_tf2_flags(parser)
else:
_get_tf1_flags(parser)
parser.add_argument(
"--experimental_new_converter",
action=_ParseBooleanFlag,
nargs="?",
default=True,
help=("Experimental flag, subject to change. Enables MLIR-based "
"conversion instead of TOCO conversion. (default True)"))
parser.add_argument(
"--experimental_new_quantizer",
action=_ParseBooleanFlag,
nargs="?",
help=("Experimental flag, subject to change. Enables MLIR-based "
"quantizer instead of flatbuffer conversion. (default True)"))
return parser
def run_main(_):
"""Main in tflite_convert.py."""
use_v2_converter = tf2.enabled()
parser = _get_parser(use_v2_converter)
tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])
# If the user is running TensorFlow 2.X but has passed in enable_v1_converter
# then parse the flags again with the 1.X converter flags.
if tf2.enabled() and tflite_flags.enable_v1_converter:
use_v2_converter = False
parser = _get_parser(use_v2_converter)
tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])
# Checks if the flags are valid.
try:
if use_v2_converter:
_check_tf2_flags(tflite_flags)
else:
_check_tf1_flags(tflite_flags, unparsed)
except ValueError as e:
parser.print_usage()
file_name = os.path.basename(sys.argv[0])
sys.stderr.write("{0}: error: {1}\n".format(file_name, str(e)))
sys.exit(1)
# Convert the model according to the user provided flag.
if use_v2_converter:
_convert_tf2_model(tflite_flags)
else:
try:
_convert_tf1_model(tflite_flags)
finally:
if tflite_flags.conversion_summary_dir:
if tflite_flags.experimental_new_converter:
gen_html.gen_conversion_log_html(tflite_flags.conversion_summary_dir,
tflite_flags.post_training_quantize,
tflite_flags.output_file)
else:
warnings.warn(
"Conversion summary will only be generated when enabling"
" the new converter via --experimental_new_converter. ")
def main():
app.run(main=run_main, argv=sys.argv[:1])
if __name__ == "__main__":
main()
@@ -0,0 +1,234 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Keras functions required by TensorFlow Lite.
The functions defined in this library have been copied over from Keras in order
to remove the dependency from TensorFlow Lite to Keras. The functions which
could not be copied over are accessed using the dependency inversion principle.
(for details, refer to tensorflow/python/util/keras_deps.py).
"""
import copy
from tensorflow.python.eager import def_function
from tensorflow.python.framework import tensor_spec
from tensorflow.python.util import keras_deps
from tensorflow.python.util import nest
from tensorflow.python.util.compat import collections_abc
def _enforce_names_consistency(specs):
"""Enforces that either all specs have names or none do."""
def _has_name(spec):
return hasattr(spec, 'name') and spec.name is not None
def _clear_name(spec):
spec = copy.deepcopy(spec)
if hasattr(spec, 'name'):
spec._name = None # pylint:disable=protected-access
return spec
flat_specs = nest.flatten(specs)
name_inconsistency = (
any(_has_name(s) for s in flat_specs) and
not all(_has_name(s) for s in flat_specs))
if name_inconsistency:
specs = nest.map_structure(_clear_name, specs)
return specs
def get_save_spec(model):
"""Returns the save spec of the subclassing keras model."""
shapes_dict = getattr(model, '_build_shapes_dict', None)
if not shapes_dict:
return None
if 'input_shape' not in shapes_dict:
raise ValueError(
'Model {} cannot be saved because the input shapes have not been set.'
)
input_shape = shapes_dict['input_shape']
if isinstance(input_shape, tuple):
shape = input_shape
shape = (None,) + shape[1:]
return tensor_spec.TensorSpec(
shape=shape, dtype=model.input_dtype
)
elif isinstance(input_shape, dict):
specs = {}
for key, shape in input_shape.items():
shape = (None,) + shape[1:]
specs[key] = tensor_spec.TensorSpec(
shape=shape, dtype=model.input_dtype, name=key
)
return specs
elif isinstance(input_shape, list):
specs = []
for shape in input_shape:
shape = (None,) + shape[1:]
specs.append(tensor_spec.TensorSpec(shape=shape, dtype=model.input_dtype))
return specs
def model_input_signature(model, keep_original_batch_size=False):
"""Inspect model to get its input signature.
The model's input signature is a list with a single (possibly-nested) object.
This is due to the Keras-enforced restriction that tensor inputs must be
passed in as the first argument.
For example, a model with input {'feature1': <Tensor>, 'feature2': <Tensor>}
will have input signature: [{'feature1': TensorSpec, 'feature2': TensorSpec}]
Args:
model: Keras Model object.
keep_original_batch_size: A boolean indicating whether we want to keep using
the original batch size or set it to None. Default is `False`, which means
that the batch dim of the returned input signature will always be set to
`None`.
Returns:
A list containing either a single TensorSpec or an object with nested
TensorSpecs. This list does not contain the `training` argument.
"""
if hasattr(model, 'save_spec'):
input_specs = model.save_spec(dynamic_batch=not keep_original_batch_size)
if input_specs is None:
return None
# The model's save spec returns (args, kwargs). Extract the first input arg
# to use as the input spec.
# TODO(b/188105669): Add support for multiple tensor arguments.
input_specs = input_specs[0][0]
else:
input_specs = model._get_save_spec( # pylint: disable=protected-access
dynamic_batch=not keep_original_batch_size)
if input_specs is None:
return None
input_specs = _enforce_names_consistency(input_specs)
# Return a list with a single element as the model's input signature.
if isinstance(input_specs,
collections_abc.Sequence) and len(input_specs) == 1:
# Note that the isinstance check filters out single-element dictionaries,
# which should also be wrapped as a single-element list.
return input_specs
else:
return [input_specs]
def raise_model_input_error(model):
raise ValueError(
'Model {} cannot be saved because the input shapes have not been '
'set. Usually, input shapes are automatically determined from calling'
' `.fit()` or `.predict()`. To manually set the shapes, call '
'`model.build(input_shape)`.'.format(model))
def _create_pseudo_names(tensors, prefix):
"""Creates pseudo {input | output} names for subclassed Models.
Warning: this function should only be used to define default
names for `Metics` and `SavedModel`. No other use cases should
rely on a `Model`'s input or output names.
Example with dict:
`{'a': [x1, x2], 'b': x3}` becomes:
`['a_1', 'a_2', 'b']`
Example with list:
`[x, y]` becomes:
`['output_1', 'output_2']`
Args:
tensors: `Model`'s outputs or inputs.
prefix: 'output_' for outputs, 'input_' for inputs.
Returns:
Flattened list of pseudo names.
"""
def one_index(ele):
# Start with "output_1" instead of "output_0".
if isinstance(ele, int):
return ele + 1
return ele
flat_paths = list(nest.yield_flat_paths(tensors))
flat_paths = nest.map_structure(one_index, flat_paths)
names = []
for path in flat_paths:
if not path:
name = prefix + '1' # Single output.
else:
name = '_'.join(str(p) for p in path)
if isinstance(path[0], int):
name = prefix + name
names.append(name)
return names
def create_pseudo_output_names(outputs):
"""Create pseudo output names for a subclassed Model."""
return _create_pseudo_names(outputs, prefix='output_')
def trace_model_call(model, input_signature=None):
"""Trace the model call to create a tf.function for exporting a Keras model.
Args:
model: A Keras model.
input_signature: optional, a list of tf.TensorSpec objects specifying the
inputs to the model.
Returns:
A tf.function wrapping the model's call function with input signatures set.
Raises:
ValueError: if input signature cannot be inferred from the model.
"""
if input_signature is None:
if isinstance(model.call, def_function.Function):
input_signature = model.call.input_signature
if input_signature is None:
input_signature = model_input_signature(model)
if input_signature is None:
raise_model_input_error(model)
@def_function.function(input_signature=input_signature, autograph=False)
def _wrapped_model(*args):
"""A concrete tf.function that wraps the model's call function."""
# When given a single input, Keras models will call the model on the tensor
# rather than a list consisting of the single tensor.
inputs = args[0] if len(input_signature) == 1 else list(args)
with keras_deps.get_call_context_function()().enter(
model,
inputs=inputs,
build_graph=False,
call_context_args={'training': False},
saving=True,
):
outputs = model(inputs, training=False)
return outputs
return _wrapped_model
@@ -0,0 +1,35 @@
from sys import modules
from types import ModuleType
def __update_globals(new_import_path, pywrap_m):
all_names = pywrap_m.__all__ if hasattr(pywrap_m, '__all__') else dir(
pywrap_m)
modules[new_import_path] = pywrap_m
for name in all_names:
sub_pywrap = getattr(pywrap_m, name)
if isinstance(sub_pywrap, ModuleType):
sub_name = sub_pywrap.__name__[len(pywrap_m.__name__):]
__update_globals(new_import_path + sub_name, sub_pywrap)
def __try_import():
imports_paths = ["litert.python._pywrap_string_util", "third_party.tensorflow.lite.python._pywrap_string_util", "tensorflow._pywrap_string_util", "tensorflow.python._pywrap_string_util"] # template_val
exceptions = []
last_exception = None
for import_path in imports_paths:
try:
pywrap_m = __import__(import_path, fromlist=["*"])
__update_globals(__name__, pywrap_m)
return
except ImportError as e:
exceptions.append(str(e))
last_exception = e
pass
raise RuntimeError(f"""
Could not import original test/binary location, import paths tried: {imports_paths}.
Previous exceptions: {exceptions}""", last_exception)
__try_import()
@@ -0,0 +1,265 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A utility class to generate the report HTML based on a common template."""
import io
import os
from tensorflow.lite.toco.logging import toco_conversion_log_pb2 as _toco_conversion_log_pb2
from tensorflow.python.lib.io import file_io as _file_io
from tensorflow.python.platform import resource_loader as _resource_loader
html_escape_table = {
"&": "&amp;",
'"': "&quot;",
"'": "&apos;",
">": "&gt;",
"<": "&lt;",
}
def html_escape(text):
return "".join(html_escape_table.get(c, c) for c in text)
def get_input_type_from_signature(op_signature):
"""Parses op_signature and returns a string denoting the input tensor type.
Args:
op_signature: a string specifying the signature of a particular operator.
The signature of an operator contains the input tensor's shape and type,
output tensor's shape and type, operator's name and its version. It has
the following schema:
INPUT:input_1_shape::input_1_type::input_2_shape::input_2_type::..
::OUTPUT:output_1_shape::output_1_type::output_2_shape::output_2_type::
..::NAME:operator_name ::VERSION:operator_version
An example of an operator signature is:
INPUT:[1,73,73,160]::float::[64,1,1,160]::float::[64]::float::
OUTPUT:[1,73,73,64]::float::NAME:Conv::VERSION:1
Returns:
A string denoting the input tensors' type. In the form of shape/type
separated
by comma. For example:
shape:[1,73,73,160],type:float,shape:[64,1,1,160],type:float,shape:[64],
type:float
"""
start = op_signature.find(":")
end = op_signature.find("::OUTPUT")
inputs = op_signature[start + 1:end]
lst = inputs.split("::")
out_str = ""
for i in range(len(lst)):
if i % 2 == 0:
out_str += "shape:"
else:
out_str += "type:"
out_str += lst[i]
out_str += ","
return out_str[:-1]
def get_operator_type(op_name, conversion_log):
if op_name in conversion_log.built_in_ops:
return "BUILT-IN"
elif op_name in conversion_log.custom_ops:
return "CUSTOM OP"
else:
return "SELECT OP"
class HTMLGenerator:
"""Utility class to generate an HTML report."""
def __init__(self, html_template_path, export_report_path):
"""Reads the HTML template content.
Args:
html_template_path: A string, path to the template HTML file.
export_report_path: A string, path to the generated HTML report. This path
should point to a '.html' file with date and time in its name.
e.g. 2019-01-01-10:05.toco_report.html.
Raises:
IOError: File doesn't exist.
"""
# Load the template HTML.
if not _file_io.file_exists(html_template_path):
raise IOError("File '{0}' does not exist.".format(html_template_path))
with _file_io.FileIO(html_template_path, "r") as f:
self.html_template = f.read()
_file_io.recursive_create_dir(os.path.dirname(export_report_path))
self.export_report_path = export_report_path
def generate(self,
toco_conversion_log_before,
toco_conversion_log_after,
post_training_quant_enabled,
dot_before,
dot_after,
toco_err_log="",
tflite_graph_path=""):
"""Generates the HTML report and writes it to local directory.
This function uses the fields in `toco_conversion_log_before` and
`toco_conversion_log_after` to populate the HTML content. Certain markers
(placeholders) in the HTML template are then substituted with the fields
from the protos. Once finished it will write the HTML file to the specified
local file path.
Args:
toco_conversion_log_before: A `TocoConversionLog` protobuf generated
before the model is converted by TOCO.
toco_conversion_log_after: A `TocoConversionLog` protobuf generated after
the model is converted by TOCO.
post_training_quant_enabled: A boolean, whether post-training quantization
is enabled.
dot_before: A string, the dot representation of the model
before the conversion.
dot_after: A string, the dot representation of the model after
the conversion.
toco_err_log: A string, the logs emitted by TOCO during conversion. Caller
need to ensure that this string is properly anonymized (any kind of
user data should be eliminated).
tflite_graph_path: A string, the filepath to the converted TFLite model.
Raises:
RuntimeError: When error occurs while generating the template.
"""
html_dict = {}
html_dict["<!--CONVERSION_STATUS-->"] = (
r'<span class="label label-danger">Fail</span>'
) if toco_err_log else r'<span class="label label-success">Success</span>'
html_dict["<!--TOTAL_OPS_BEFORE_CONVERT-->"] = str(
toco_conversion_log_before.model_size)
html_dict["<!--TOTAL_OPS_AFTER_CONVERT-->"] = str(
toco_conversion_log_after.model_size)
html_dict["<!--BUILT_IN_OPS_COUNT-->"] = str(
sum(toco_conversion_log_after.built_in_ops.values()))
html_dict["<!--SELECT_OPS_COUNT-->"] = str(
sum(toco_conversion_log_after.select_ops.values()))
html_dict["<!--CUSTOM_OPS_COUNT-->"] = str(
sum(toco_conversion_log_after.custom_ops.values()))
html_dict["<!--POST_TRAINING_QUANT_ENABLED-->"] = (
"is" if post_training_quant_enabled else "isn't")
pre_op_profile = ""
post_op_profile = ""
# Generate pre-conversion op profiles as a list of HTML table rows.
for i in range(len(toco_conversion_log_before.op_list)):
# Append operator name column.
pre_op_profile += "<tr><td>" + toco_conversion_log_before.op_list[
i] + "</td>"
# Append input type column.
if i < len(toco_conversion_log_before.op_signatures):
pre_op_profile += "<td>" + get_input_type_from_signature(
toco_conversion_log_before.op_signatures[i]) + "</td></tr>"
else:
pre_op_profile += "<td></td></tr>"
# Generate post-conversion op profiles as a list of HTML table rows.
for op in toco_conversion_log_after.op_list:
supported_type = get_operator_type(op, toco_conversion_log_after)
post_op_profile += ("<tr><td>" + op + "</td><td>" + supported_type +
"</td></tr>")
html_dict["<!--REPEAT_TABLE1_ROWS-->"] = pre_op_profile
html_dict["<!--REPEAT_TABLE2_ROWS-->"] = post_op_profile
html_dict["<!--DOT_BEFORE_CONVERT-->"] = dot_before
html_dict["<!--DOT_AFTER_CONVERT-->"] = dot_after
if toco_err_log:
html_dict["<!--TOCO_INFO_LOG-->"] = html_escape(toco_err_log)
else:
success_info = ("TFLite graph conversion successful. You can preview the "
"converted model at: ") + tflite_graph_path
html_dict["<!--TOCO_INFO_LOG-->"] = html_escape(success_info)
# Replace each marker (as keys of html_dict) with the actual text (as values
# of html_dict) in the HTML template string.
template = self.html_template
for marker in html_dict:
template = template.replace(marker, html_dict[marker], 1)
# Check that the marker text is replaced.
if template.find(marker) != -1:
raise RuntimeError("Could not populate marker text %r" % marker)
with _file_io.FileIO(self.export_report_path, "w") as f:
f.write(template)
def gen_conversion_log_html(conversion_log_dir, quantization_enabled,
tflite_graph_path):
"""Generates an HTML report about the conversion process.
Args:
conversion_log_dir: A string specifying the file directory of the conversion
logs. It's required that before calling this function, the
`conversion_log_dir`
already contains the following files: `toco_log_before.pb`,
`toco_log_after.pb`, `toco_tf_graph.dot`,
`toco_tflite_graph.dot`.
quantization_enabled: A boolean, passed from the tflite converter to
indicate whether post-training quantization is enabled during conversion.
tflite_graph_path: A string, the filepath to the converted TFLite model.
Raises:
IOError: When any of the required files doesn't exist.
"""
template_filename = _resource_loader.get_path_to_datafile("template.html")
if not os.path.exists(template_filename):
raise IOError("Failed to generate HTML: file '{0}' doesn't exist.".format(
template_filename))
toco_log_before_path = os.path.join(conversion_log_dir, "toco_log_before.pb")
toco_log_after_path = os.path.join(conversion_log_dir, "toco_log_after.pb")
dot_before_path = os.path.join(conversion_log_dir, "toco_tf_graph.dot")
dot_after_path = os.path.join(conversion_log_dir, "toco_tflite_graph.dot")
if not os.path.exists(toco_log_before_path):
raise IOError("Failed to generate HTML: file '{0}' doesn't exist.".format(
toco_log_before_path))
if not os.path.exists(toco_log_after_path):
raise IOError("Failed to generate HTML: file '{0}' doesn't exist.".format(
toco_log_after_path))
if not os.path.exists(dot_before_path):
raise IOError("Failed to generate HTML: file '{0}' doesn't exist.".format(
dot_before_path))
if not os.path.exists(dot_after_path):
raise IOError("Failed to generate HTML: file '{0}' doesn't exist.".format(
dot_after_path))
html_generator = HTMLGenerator(
template_filename,
os.path.join(conversion_log_dir, "toco_conversion_summary.html"))
# Parse the generated `TocoConversionLog`.
toco_conversion_log_before = _toco_conversion_log_pb2.TocoConversionLog()
toco_conversion_log_after = _toco_conversion_log_pb2.TocoConversionLog()
with open(toco_log_before_path, "rb") as f:
toco_conversion_log_before.ParseFromString(f.read())
with open(toco_log_after_path, "rb") as f:
toco_conversion_log_after.ParseFromString(f.read())
# Read the dot file before/after the conversion.
with io.open(dot_before_path, "r", encoding="utf-8") as f:
dot_before = f.read().rstrip()
with io.open(dot_after_path, "r", encoding="utf-8") as f:
dot_after = f.read().rstrip()
html_generator.generate(toco_conversion_log_before, toco_conversion_log_after,
quantization_enabled, dot_before, dot_after,
toco_conversion_log_after.toco_err_logs,
tflite_graph_path)
@@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: tensorflow/lite/toco/logging/toco_conversion_log.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
28,
3,
'',
'tensorflow/lite/toco/logging/toco_conversion_log.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n6tensorflow/lite/toco/logging/toco_conversion_log.proto\x12\x04toco\"\xc9\x04\n\x11TocoConversionLog\x12\x0f\n\x07op_list\x18\x01 \x03(\t\x12=\n\x0c\x62uilt_in_ops\x18\x02 \x03(\x0b\x32\'.toco.TocoConversionLog.BuiltInOpsEntry\x12:\n\ncustom_ops\x18\x03 \x03(\x0b\x32&.toco.TocoConversionLog.CustomOpsEntry\x12:\n\nselect_ops\x18\x04 \x03(\x0b\x32&.toco.TocoConversionLog.SelectOpsEntry\x12\x15\n\rop_signatures\x18\x05 \x03(\t\x12\x1a\n\x12input_tensor_types\x18\x06 \x03(\t\x12\x1b\n\x13output_tensor_types\x18\x07 \x03(\t\x12\x19\n\x11log_generation_ts\x18\x08 \x01(\x03\x12\x12\n\nmodel_size\x18\t \x01(\x05\x12\x17\n\x0ftf_lite_version\x18\n \x01(\t\x12\x12\n\nos_version\x18\x0b \x01(\t\x12\x12\n\nmodel_hash\x18\x0c \x01(\t\x12\x15\n\rtoco_err_logs\x18\r \x01(\t\x1a\x31\n\x0f\x42uiltInOpsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x05:\x02\x38\x01\x1a\x30\n\x0e\x43ustomOpsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x05:\x02\x38\x01\x1a\x30\n\x0eSelectOpsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x05:\x02\x38\x01')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.lite.toco.logging.toco_conversion_log_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TOCOCONVERSIONLOG_BUILTINOPSENTRY']._loaded_options = None
_globals['_TOCOCONVERSIONLOG_BUILTINOPSENTRY']._serialized_options = b'8\001'
_globals['_TOCOCONVERSIONLOG_CUSTOMOPSENTRY']._loaded_options = None
_globals['_TOCOCONVERSIONLOG_CUSTOMOPSENTRY']._serialized_options = b'8\001'
_globals['_TOCOCONVERSIONLOG_SELECTOPSENTRY']._loaded_options = None
_globals['_TOCOCONVERSIONLOG_SELECTOPSENTRY']._serialized_options = b'8\001'
_globals['_TOCOCONVERSIONLOG']._serialized_start=65
_globals['_TOCOCONVERSIONLOG']._serialized_end=650
_globals['_TOCOCONVERSIONLOG_BUILTINOPSENTRY']._serialized_start=501
_globals['_TOCOCONVERSIONLOG_BUILTINOPSENTRY']._serialized_end=550
_globals['_TOCOCONVERSIONLOG_CUSTOMOPSENTRY']._serialized_start=552
_globals['_TOCOCONVERSIONLOG_CUSTOMOPSENTRY']._serialized_end=600
_globals['_TOCOCONVERSIONLOG_SELECTOPSENTRY']._serialized_start=602
_globals['_TOCOCONVERSIONLOG_SELECTOPSENTRY']._serialized_end=650
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: tensorflow/lite/toco/toco_flags.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
28,
3,
'',
'tensorflow/lite/toco/toco_flags.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
from tensorflow.compiler.mlir.lite.debug import debug_options_pb2 as tensorflow_dot_compiler_dot_mlir_dot_lite_dot_debug_dot_debug__options__pb2
from tensorflow.compiler.mlir.quantization.stablehlo import quantization_config_pb2 as tensorflow_dot_compiler_dot_mlir_dot_quantization_dot_stablehlo_dot_quantization__config__pb2
from tensorflow.compiler.mlir.quantization.stablehlo import quantization_options_pb2 as tensorflow_dot_compiler_dot_mlir_dot_quantization_dot_stablehlo_dot_quantization__options__pb2
from tensorflow.lite.toco import types_pb2 as tensorflow_dot_lite_dot_toco_dot_types__pb2
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n%tensorflow/lite/toco/toco_flags.proto\x12\x04toco\x1a\x37tensorflow/compiler/mlir/lite/debug/debug_options.proto\x1aItensorflow/compiler/mlir/quantization/stablehlo/quantization_config.proto\x1aJtensorflow/compiler/mlir/quantization/stablehlo/quantization_options.proto\x1a tensorflow/lite/toco/types.proto\"\x94\x14\n\tTocoFlags\x12&\n\x0cinput_format\x18\x01 \x01(\x0e\x32\x10.toco.FileFormat\x12\'\n\routput_format\x18\x02 \x01(\x0e\x32\x10.toco.FileFormat\x12.\n\x14inference_input_type\x18\x0b \x01(\x0e\x32\x10.toco.IODataType\x12(\n\x0einference_type\x18\x04 \x01(\x0e\x32\x10.toco.IODataType\x12\x1a\n\x12\x64\x65\x66\x61ult_ranges_min\x18\x05 \x01(\x02\x12\x1a\n\x12\x64\x65\x66\x61ult_ranges_max\x18\x06 \x01(\x02\x12 \n\x18\x64\x65\x66\x61ult_int16_ranges_min\x18\x0f \x01(\x02\x12 \n\x18\x64\x65\x66\x61ult_int16_ranges_max\x18\x10 \x01(\x02\x12\x17\n\x0f\x64rop_fake_quant\x18\x07 \x01(\x08\x12!\n\x19reorder_across_fake_quant\x18\x08 \x01(\x08\x12\x18\n\x10\x61llow_custom_ops\x18\n \x01(\x08\x12\x1f\n\x17\x64rop_control_dependency\x18\x0c \x01(\x08\x12+\n#debug_disable_recurrent_cell_fusion\x18\r \x01(\x08\x12%\n\x1dpropagate_fake_quant_num_bits\x18\x0e \x01(\x08\x12\x35\n-allow_nudging_weights_to_use_fast_gemm_kernel\x18\x11 \x01(\x08\x12\'\n\x1b\x64\x65\x64upe_array_min_size_bytes\x18\x12 \x01(\x03:\x02\x36\x34\x12&\n\x18split_tflite_lstm_inputs\x18\x13 \x01(\x08:\x04true\x12\x1f\n\x10quantize_weights\x18\x14 \x01(\x08:\x05\x66\x61lse\x12\x19\n\x11\x64ump_graphviz_dir\x18\x18 \x01(\t\x12#\n\x1b\x64ump_graphviz_include_video\x18\x19 \x01(\x08\x12%\n\x16post_training_quantize\x18\x1a \x01(\x08:\x05\x66\x61lse\x12#\n\x14\x65nable_select_tf_ops\x18\x1b \x01(\x08:\x05\x66\x61lse\x12\"\n\x13\x66orce_select_tf_ops\x18\x1c \x01(\x08:\x05\x66\x61lse\x12\"\n\x13quantize_to_float16\x18\x1d \x01(\x08:\x05\x66\x61lse\x12#\n\x15\x61llow_dynamic_tensors\x18\x1e \x01(\x08:\x04true\x12\x1e\n\x16\x63onversion_summary_dir\x18\x1f \x01(\t\x12\x19\n\rcustom_opdefs\x18 \x03(\tB\x02\x18\x01\x12\x1a\n\x12select_user_tf_ops\x18! \x03(\t\x12.\n enable_tflite_resource_variables\x18\" \x01(\x08:\x04true\x12!\n\x12unfold_batchmatmul\x18# \x01(\x08:\x05\x66\x61lse\x12#\n\x15lower_tensor_list_ops\x18$ \x01(\x08:\x04true\x12+\n\x11\x61\x63\x63umulation_type\x18% \x01(\x0e\x32\x10.toco.IODataType\x12\x1d\n\x0e\x61llow_bfloat16\x18& \x01(\x08:\x05\x66\x61lse\x12\x1f\n\x17\x61llow_all_select_tf_ops\x18\' \x01(\x08\x12*\n\x1bunfold_large_splat_constant\x18( \x01(\x08:\x05\x66\x61lse\x12\x1a\n\x12supported_backends\x18) \x03(\t\x12\x39\n*default_to_single_batch_in_tensor_list_ops\x18* \x01(\x08:\x05\x66\x61lse\x12/\n disable_per_channel_quantization\x18+ \x01(\x08:\x05\x66\x61lse\x12\x32\n#enable_mlir_dynamic_range_quantizer\x18, \x01(\x08:\x05\x66\x61lse\x12\x1c\n\x14tf_quantization_mode\x18- \x01(\t\x12)\n\x1a\x64isable_infer_tensor_range\x18. \x01(\x08:\x05\x66\x61lse\x12&\n\x17use_fake_quant_num_bits\x18/ \x01(\x08:\x05\x66\x61lse\x12*\n\x1b\x65nable_dynamic_update_slice\x18\x30 \x01(\x08:\x05\x66\x61lse\x12!\n\x12preserve_assert_op\x18\x31 \x01(\x08:\x05\x66\x61lse\x12*\n\x1bguarantee_all_funcs_one_use\x18\x32 \x01(\x08:\x05\x66\x61lse\x12#\n\x14\x63onvert_to_stablehlo\x18\x33 \x01(\x08:\x05\x66\x61lse\x12\x30\n!enable_mlir_variable_quantization\x18\x34 \x01(\x08:\x05\x66\x61lse\x12&\n\x17\x64isable_fuse_mul_and_fc\x18\x35 \x01(\x08:\x05\x66\x61lse\x12M\n\x14quantization_options\x18\x36 \x01(\x0b\x32+.stablehlo.quantization.QuantizationOptionsB\x02\x18\x01\x12.\n\x1b\x65nable_hlo_to_tf_conversion\x18\x37 \x01(\x08:\x05\x66\x61lseB\x02\x18\x01\x12\x39\n\rdebug_options\x18\x38 \x01(\x0b\x32\".tensorflow.converter.DebugOptions\x12 \n\x11use_buffer_offset\x18\x39 \x01(\x08:\x05\x66\x61lse\x12.\n\x1flegalize_custom_tensor_list_ops\x18: \x01(\x08:\x05\x66\x61lse\x12$\n\x15reduce_type_precision\x18; \x01(\x08:\x05\x66\x61lse\x12!\n\x13qdq_conversion_mode\x18< \x01(\t:\x04NONE\x12G\n\x13quantization_config\x18= \x01(\x0b\x32*.stablehlo.quantization.QuantizationConfig\x12@\n1disable_per_channel_quantization_for_dense_layers\x18> \x01(\x08:\x05\x66\x61lse\x12/\n enable_composite_direct_lowering\x18? \x01(\x08:\x05\x66\x61lse\x12K\n\x16model_origin_framework\x18@ \x01(\x0e\x32$.toco.TocoFlags.ModelOriginFramework:\x05UNSET\x12\x32\n#canonicalizing_inf_as_min_max_float\x18\x41 \x01(\x08:\x05\x66\x61lse\"R\n\x14ModelOriginFramework\x12\t\n\x05UNSET\x10\x00\x12\x0e\n\nTENSORFLOW\x10\x01\x12\t\n\x05KERAS\x10\x02\x12\x07\n\x03JAX\x10\x03\x12\x0b\n\x07PYTORCH\x10\x04*\\\n\nFileFormat\x12\x17\n\x13\x46ILE_FORMAT_UNKNOWN\x10\x00\x12\x17\n\x13TENSORFLOW_GRAPHDEF\x10\x01\x12\n\n\x06TFLITE\x10\x02\x12\x10\n\x0cGRAPHVIZ_DOT\x10\x03')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.lite.toco.toco_flags_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_TOCOFLAGS'].fields_by_name['custom_opdefs']._loaded_options = None
_globals['_TOCOFLAGS'].fields_by_name['custom_opdefs']._serialized_options = b'\030\001'
_globals['_TOCOFLAGS'].fields_by_name['quantization_options']._loaded_options = None
_globals['_TOCOFLAGS'].fields_by_name['quantization_options']._serialized_options = b'\030\001'
_globals['_TOCOFLAGS'].fields_by_name['enable_hlo_to_tf_conversion']._loaded_options = None
_globals['_TOCOFLAGS'].fields_by_name['enable_hlo_to_tf_conversion']._serialized_options = b'\030\001'
_globals['_FILEFORMAT']._serialized_start=2872
_globals['_FILEFORMAT']._serialized_end=2964
_globals['_TOCOFLAGS']._serialized_start=290
_globals['_TOCOFLAGS']._serialized_end=2870
_globals['_TOCOFLAGS_MODELORIGINFRAMEWORK']._serialized_start=2788
_globals['_TOCOFLAGS_MODELORIGINFRAMEWORK']._serialized_end=2870
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: tensorflow/lite/toco/types.proto
# Protobuf Python Version: 5.28.3
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
28,
3,
'',
'tensorflow/lite/toco/types.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n tensorflow/lite/toco/types.proto\x12\x04toco*\xb3\x02\n\nIODataType\x12\x18\n\x14IO_DATA_TYPE_UNKNOWN\x10\x00\x12\t\n\x05\x46LOAT\x10\x01\x12\x13\n\x0fQUANTIZED_UINT8\x10\x02\x12\t\n\x05INT32\x10\x03\x12\t\n\x05INT64\x10\x04\x12\n\n\x06STRING\x10\x05\x12\x13\n\x0fQUANTIZED_INT16\x10\x06\x12\x08\n\x04\x42OOL\x10\x07\x12\r\n\tCOMPLEX64\x10\x08\x12\x12\n\x0eQUANTIZED_INT8\x10\t\x12\x0b\n\x07\x46LOAT16\x10\n\x12\x0b\n\x07\x46LOAT64\x10\x0b\x12\x0e\n\nCOMPLEX128\x10\x0c\x12\n\n\x06UINT64\x10\r\x12\x0c\n\x08RESOURCE\x10\x0e\x12\x0b\n\x07VARIANT\x10\x0f\x12\n\n\x06UINT32\x10\x10\x12\t\n\x05UINT8\x10\x11\x12\x08\n\x04INT8\x10\x12\x12\t\n\x05INT16\x10\x13\x12\n\n\x06UINT16\x10\x14')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'tensorflow.lite.toco.types_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_IODATATYPE']._serialized_start=43
_globals['_IODATATYPE']._serialized_end=350
# @@protoc_insertion_point(module_scope)
@@ -0,0 +1,517 @@
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions for FlatBuffers.
All functions that are commonly used to work with FlatBuffers.
Refer to the tensorflow lite flatbuffer schema here:
tensorflow/lite/schema/schema.fbs
"""
import copy
import random
import re
import struct
import sys
from typing import Optional, Type, TypeVar, Union
import flatbuffers
from tensorflow.lite.python import schema_py_generated as schema_fb
from tensorflow.lite.python import schema_util
from tensorflow.python.platform import gfile
_TFLITE_FILE_IDENTIFIER = b'TFL3'
def convert_bytearray_to_object(model_bytearray):
"""Converts a tflite model from a bytearray to an object for parsing."""
model_object = schema_fb.Model.GetRootAsModel(model_bytearray, 0)
return schema_fb.ModelT.InitFromObj(model_object)
def read_model(input_tflite_file):
"""Reads a tflite model as a python object.
Args:
input_tflite_file: Full path name to the input tflite file
Raises:
RuntimeError: If input_tflite_file path is invalid.
IOError: If input_tflite_file cannot be opened.
Returns:
A python object corresponding to the input tflite file.
"""
if not gfile.Exists(input_tflite_file):
raise RuntimeError('Input file not found at %r\n' % input_tflite_file)
with gfile.GFile(input_tflite_file, 'rb') as input_file_handle:
model_bytearray = bytearray(input_file_handle.read())
return read_model_from_bytearray(model_bytearray)
def read_model_from_bytearray(model_bytearray):
"""Reads a tflite model as a python object.
Args:
model_bytearray: TFLite model in bytearray format.
Returns:
A python object corresponding to the input tflite file.
"""
model = convert_bytearray_to_object(model_bytearray)
if sys.byteorder == 'big':
byte_swap_tflite_model_obj(model, 'little', 'big')
# Offset handling for models > 2GB
for buffer in model.buffers:
if buffer.offset:
buffer.data = model_bytearray[buffer.offset : buffer.offset + buffer.size]
buffer.offset = 0
buffer.size = 0
for subgraph in model.subgraphs:
for op in subgraph.operators:
if op.largeCustomOptionsOffset:
op.customOptions = model_bytearray[
op.largeCustomOptionsOffset : op.largeCustomOptionsOffset
+ op.largeCustomOptionsSize
]
op.largeCustomOptionsOffset = 0
op.largeCustomOptionsSize = 0
return model
def read_model_with_mutable_tensors(input_tflite_file):
"""Reads a tflite model as a python object with mutable tensors.
Similar to read_model() with the addition that the returned object has
mutable tensors (read_model() returns an object with immutable tensors).
NOTE: This API only works for TFLite generated with
_experimental_use_buffer_offset=false
Args:
input_tflite_file: Full path name to the input tflite file
Raises:
RuntimeError: If input_tflite_file path is invalid.
IOError: If input_tflite_file cannot be opened.
Returns:
A mutable python object corresponding to the input tflite file.
"""
return copy.deepcopy(read_model(input_tflite_file))
def convert_object_to_bytearray(model_object, extra_buffer=b''):
"""Converts a tflite model from an object to a immutable bytearray."""
# Initial size of the buffer, which will grow automatically if needed
builder = flatbuffers.Builder(1024)
model_offset = model_object.Pack(builder)
builder.Finish(model_offset, file_identifier=_TFLITE_FILE_IDENTIFIER)
model_bytearray = bytes(builder.Output())
model_bytearray = model_bytearray + extra_buffer
return model_bytearray
def write_model(model_object, output_tflite_file):
"""Writes the tflite model, a python object, into the output file.
NOTE: This API only works for TFLite generated with
_experimental_use_buffer_offset=false
Args:
model_object: A tflite model as a python object
output_tflite_file: Full path name to the output tflite file.
Raises:
IOError: If output_tflite_file path is invalid or cannot be opened.
"""
if sys.byteorder == 'big':
model_object = copy.deepcopy(model_object)
byte_swap_tflite_model_obj(model_object, 'big', 'little')
model_bytearray = convert_object_to_bytearray(model_object)
with gfile.GFile(output_tflite_file, 'wb') as output_file_handle:
output_file_handle.write(model_bytearray)
def strip_strings(model):
"""Strips all nonessential strings from the model to reduce model size.
We remove the following strings:
(find strings by searching ":string" in the tensorflow lite flatbuffer schema)
1. Model description
2. SubGraph name
3. Tensor names
We retain OperatorCode custom_code and Metadata name.
Args:
model: The model from which to remove nonessential strings.
"""
model.description = None
for subgraph in model.subgraphs:
subgraph.name = None
for tensor in subgraph.tensors:
tensor.name = None
# We clear all signature_def structure, since without names it is useless.
model.signatureDefs = None
def type_to_name(tensor_type):
"""Converts a numerical enum to a readable tensor type."""
for name, value in schema_fb.TensorType.__dict__.items():
if value == tensor_type:
return name
return None
def randomize_weights(model, random_seed=0, buffers_to_skip=None):
"""Randomize weights in a model.
Args:
model: The model in which to randomize weights.
random_seed: The input to the random number generator (default value is 0).
buffers_to_skip: The list of buffer indices to skip. The weights in these
buffers are left unmodified.
"""
# The input to the random seed generator. The default value is 0.
random.seed(random_seed)
# Parse model buffers which store the model weights
buffers = model.buffers
buffer_ids = range(1, len(buffers)) # ignore index 0 as it's always None
if buffers_to_skip is not None:
buffer_ids = [idx for idx in buffer_ids if idx not in buffers_to_skip]
buffer_types = {}
for graph in model.subgraphs:
for op in graph.operators:
if op.inputs is None:
break
for input_idx in op.inputs:
tensor = graph.tensors[input_idx]
buffer_types[tensor.buffer] = type_to_name(tensor.type)
for i in buffer_ids:
buffer_i_data = buffers[i].data
buffer_i_size = 0 if buffer_i_data is None else buffer_i_data.size
if buffer_i_size == 0:
continue
# Raw data buffers are of type ubyte (or uint8) whose values lie in the
# range [0, 255]. Those ubytes (or unint8s) are the underlying
# representation of each datatype. For example, a bias tensor of type
# int32 appears as a buffer 4 times it's length of type ubyte (or uint8).
# For floats, we need to generate a valid float and then pack it into
# the raw bytes in place.
buffer_type = buffer_types.get(i, 'INT8')
if buffer_type.startswith('FLOAT'):
format_code = 'e' if buffer_type == 'FLOAT16' else 'f'
for offset in range(0, buffer_i_size, struct.calcsize(format_code)):
value = random.uniform(-0.5, 0.5) # See http://b/152324470#comment2
struct.pack_into(format_code, buffer_i_data, offset, value)
else:
for j in range(buffer_i_size):
buffer_i_data[j] = random.randint(0, 255)
def rename_custom_ops(model, map_custom_op_renames):
"""Rename custom ops so they use the same naming style as builtin ops.
Args:
model: The input tflite model.
map_custom_op_renames: A mapping from old to new custom op names.
"""
for op_code in model.operatorCodes:
if op_code.customCode:
op_code_str = op_code.customCode.decode('ascii')
if op_code_str in map_custom_op_renames:
op_code.customCode = map_custom_op_renames[op_code_str].encode('ascii')
def opcode_to_name(model, op_code):
"""Converts a TFLite op_code to the human readable name.
Args:
model: The input tflite model.
op_code: The op_code to resolve to a readable name.
Returns:
A string containing the human readable op name, or None if not resolvable.
"""
op = model.operatorCodes[op_code]
code = max(op.builtinCode, op.deprecatedBuiltinCode)
for name, value in vars(schema_fb.BuiltinOperator).items():
if value == code:
return name
return None
def xxd_output_to_bytes(input_cc_file):
"""Converts xxd output C++ source file to bytes (immutable).
Args:
input_cc_file: Full path name to th C++ source file dumped by xxd
Raises:
RuntimeError: If input_cc_file path is invalid.
IOError: If input_cc_file cannot be opened.
Returns:
A bytearray corresponding to the input cc file array.
"""
# Match hex values in the string with comma as separator
pattern = re.compile(r'\W*(0x[0-9a-fA-F,x ]+).*')
model_bytearray = bytearray()
with open(input_cc_file) as file_handle:
for line in file_handle:
values_match = pattern.match(line)
if values_match is None:
continue
# Match in the parentheses (hex array only)
list_text = values_match.group(1)
# Extract hex values (text) from the line
# e.g. 0x1c, 0x00, 0x00, 0x00, 0x54, 0x46, 0x4c,
values_text = filter(None, list_text.split(','))
# Convert to hex
values = [int(x, base=16) for x in values_text]
model_bytearray.extend(values)
return bytes(model_bytearray)
def xxd_output_to_object(input_cc_file):
"""Converts xxd output C++ source file to object.
Args:
input_cc_file: Full path name to th C++ source file dumped by xxd
Raises:
RuntimeError: If input_cc_file path is invalid.
IOError: If input_cc_file cannot be opened.
Returns:
A python object corresponding to the input tflite file.
"""
model_bytes = xxd_output_to_bytes(input_cc_file)
return convert_bytearray_to_object(model_bytes)
def byte_swap_buffer_content(buffer, chunksize, from_endiness, to_endiness):
"""Helper function for byte-swapping the buffers field."""
to_swap = [
buffer.data[i : i + chunksize]
for i in range(0, len(buffer.data), chunksize)
]
buffer.data = b''.join([
int.from_bytes(byteswap, from_endiness).to_bytes(chunksize, to_endiness)
for byteswap in to_swap
])
def byte_swap_string_content(buffer, from_endiness, to_endiness):
"""Helper function for byte-swapping the string buffer.
Args:
buffer: TFLite string buffer of from_endiness format.
from_endiness: The original endianness format of the string buffer.
to_endiness: The destined endianness format of the string buffer.
"""
num_of_strings = int.from_bytes(buffer.data[0:4], from_endiness)
string_content = bytearray(buffer.data[4 * (num_of_strings + 2) :])
prefix_data = b''.join([
int.from_bytes(buffer.data[i : i + 4], from_endiness).to_bytes(
4, to_endiness
)
for i in range(0, (num_of_strings + 1) * 4 + 1, 4)
])
buffer.data = prefix_data + string_content
def byte_swap_tflite_model_obj(model, from_endiness, to_endiness):
"""Byte swaps the buffers field in a TFLite model.
Args:
model: TFLite model object of from_endiness format.
from_endiness: The original endianness format of the buffers in model.
to_endiness: The destined endianness format of the buffers in model.
"""
if model is None:
return
# Get all the constant buffers, byte swapping them as per their data types
buffer_swapped = []
types_of_16_bits = [
schema_fb.TensorType.FLOAT16,
schema_fb.TensorType.INT16,
schema_fb.TensorType.UINT16,
]
types_of_32_bits = [
schema_fb.TensorType.FLOAT32,
schema_fb.TensorType.INT32,
schema_fb.TensorType.COMPLEX64,
schema_fb.TensorType.UINT32,
]
types_of_64_bits = [
schema_fb.TensorType.INT64,
schema_fb.TensorType.FLOAT64,
schema_fb.TensorType.COMPLEX128,
schema_fb.TensorType.UINT64,
]
for subgraph in model.subgraphs:
for tensor in subgraph.tensors:
if (
tensor.buffer > 0
and tensor.buffer < len(model.buffers)
and tensor.buffer not in buffer_swapped
and model.buffers[tensor.buffer].data is not None
):
if tensor.type == schema_fb.TensorType.STRING:
byte_swap_string_content(
model.buffers[tensor.buffer], from_endiness, to_endiness
)
elif tensor.type in types_of_16_bits:
byte_swap_buffer_content(
model.buffers[tensor.buffer], 2, from_endiness, to_endiness
)
elif tensor.type in types_of_32_bits:
byte_swap_buffer_content(
model.buffers[tensor.buffer], 4, from_endiness, to_endiness
)
elif tensor.type in types_of_64_bits:
byte_swap_buffer_content(
model.buffers[tensor.buffer], 8, from_endiness, to_endiness
)
else:
continue
buffer_swapped.append(tensor.buffer)
def byte_swap_tflite_buffer(tflite_model, from_endiness, to_endiness):
"""Generates a new model byte array after byte swapping its buffers field.
Args:
tflite_model: TFLite flatbuffer in a byte array.
from_endiness: The original endianness format of the buffers in
tflite_model.
to_endiness: The destined endianness format of the buffers in tflite_model.
Returns:
TFLite flatbuffer in a byte array, after being byte swapped to to_endiness
format.
"""
if tflite_model is None:
return None
# Load TFLite Flatbuffer byte array into an object.
model = convert_bytearray_to_object(tflite_model)
# Byte swapping the constant buffers as per their data types
byte_swap_tflite_model_obj(model, from_endiness, to_endiness)
# Return a TFLite flatbuffer as a byte array.
return convert_object_to_bytearray(model)
def count_resource_variables(model):
"""Calculates the number of unique resource variables in a model.
Args:
model: the input tflite model, either as bytearray or object.
Returns:
An integer number representing the number of unique resource variables.
"""
if not isinstance(model, schema_fb.ModelT):
model = convert_bytearray_to_object(model)
unique_shared_names = set()
for subgraph in model.subgraphs:
if subgraph.operators is None:
continue
for op in subgraph.operators:
builtin_code = schema_util.get_builtin_code_from_operator_code(
model.operatorCodes[op.opcodeIndex]
)
if builtin_code == schema_fb.BuiltinOperator.VAR_HANDLE:
unique_shared_names.add(op.builtinOptions.sharedName)
return len(unique_shared_names)
OptsT = TypeVar('OptsT')
def get_options_as(
op: Union[schema_fb.Operator, schema_fb.OperatorT], opts_type: Type[OptsT]
) -> Optional[OptsT]:
"""Get the options of an operator as the specified type.
Requested type must be an object-api type (ends in 'T').
Args:
op: The operator to get the options from.
opts_type: The type of the options to get.
Returns:
The options as the specified type, or None if the options are not of the
specified type.
Raises:
ValueError: If the specified type is not a valid options type.
"""
err = ValueError(f'Unsupported options type: {opts_type}')
type_name: str = opts_type.__name__
if not type_name.endswith('T'):
raise err
base_type_name = type_name.removesuffix('T')
is_opt_1_type = hasattr(schema_fb.BuiltinOptions, base_type_name)
if not is_opt_1_type and not hasattr(
schema_fb.BuiltinOptions2, base_type_name
):
raise err
if isinstance(op, schema_fb.Operator):
if not is_opt_1_type:
enum_val = getattr(schema_fb.BuiltinOptions2, base_type_name)
opts_creator = schema_fb.BuiltinOptions2Creator
raw_ops = op.BuiltinOptions2()
actual_enum_val = op.BuiltinOptions2Type()
else:
enum_val = getattr(schema_fb.BuiltinOptions, base_type_name)
opts_creator = schema_fb.BuiltinOptionsCreator
raw_ops = op.BuiltinOptions()
actual_enum_val = op.BuiltinOptionsType()
if raw_ops is None or actual_enum_val != enum_val:
return None
return opts_creator(enum_val, raw_ops)
elif isinstance(op, schema_fb.OperatorT):
if is_opt_1_type:
raw_ops_t = op.builtinOptions
else:
raw_ops_t = op.builtinOptions2
if raw_ops_t is None or not isinstance(raw_ops_t, opts_type):
return None
return raw_ops_t
else:
return None
@@ -0,0 +1,549 @@
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python TF-Lite QuantizationDebugger."""
import collections
import csv
import re
from typing import (Any, Callable, Dict, IO, Iterable, List, Mapping, Optional,
Sequence, Tuple)
import numpy as np
from tensorflow.lite.python import convert
from tensorflow.lite.python import interpreter as _interpreter
from tensorflow.lite.python.metrics import metrics as metrics_stub # type: ignore
from tensorflow.python.util import tf_export
# TODO(b/198099651): move converter implementation out of lite.py
TFLiteConverter = Any # importing tf.lite creates circular dependency
# Returns metrics based on difference of values for quantized/float ops.
_DEFAULT_LAYER_DEBUG_METRICS = {
'num_elements': lambda diffs: diffs.size,
'stddev': np.std,
'mean_error': np.average,
'max_abs_error': lambda diffs: np.max(np.abs(diffs)),
'mean_squared_error': lambda diffs: np.average(diffs**2),
}
_NUMERIC_VERIFY_OP_NAME = 'NumericVerify'
def _get_quant_params(
tensor_detail: Mapping[str, Any]) -> Optional[Tuple[float, int]]:
"""Returns first scale and zero point from tensor detail, if present."""
quant_params = tensor_detail['quantization_parameters']
if not quant_params:
return None
if quant_params['scales'] and quant_params['zero_points']:
return (quant_params['scales'][0], quant_params['zero_points'][0])
return None
@tf_export.tf_export('lite.experimental.QuantizationDebugOptions')
class QuantizationDebugOptions:
"""Debug options to set up a given QuantizationDebugger."""
def __init__(self,
layer_debug_metrics: Optional[Mapping[str,
Callable[[np.ndarray],
float]]] = None,
model_debug_metrics: Optional[Mapping[
str, Callable[[Sequence[np.ndarray], Sequence[np.ndarray]],
float]]] = None,
layer_direct_compare_metrics: Optional[Mapping[str, Callable[
[Sequence[np.ndarray], Sequence[np.ndarray], float, int],
float]]] = None,
denylisted_ops: Optional[List[str]] = None,
denylisted_nodes: Optional[List[str]] = None,
fully_quantize: bool = False) -> None:
"""Initializes debugger options.
Args:
layer_debug_metrics: a dict to specify layer debug functions
{function_name_str: function} where the function accepts result of
NumericVerify Op, which is value difference between float and
dequantized op results. The function returns single scalar value.
model_debug_metrics: a dict to specify model debug functions
{function_name_str: function} where the function accepts outputs from
two models, and returns single scalar value for a metric. (e.g.
accuracy, IoU)
layer_direct_compare_metrics: a dict to specify layer debug functions
{function_name_str: function}. The signature is different from that of
`layer_debug_metrics`, and this one gets passed (original float value,
original quantized value, scale, zero point). The function's
implementation is responsible for correctly dequantize the quantized
value to compare. Use this one when comparing diff is not enough.
(Note) quantized value is passed as int8, so cast to int32 is needed.
denylisted_ops: a list of op names which is expected to be removed from
quantization.
denylisted_nodes: a list of op's output tensor names to be removed from
quantization.
fully_quantize: Bool indicating whether to fully quantize the model.
Besides model body, the input/output will be quantized as well.
Corresponding to mlir_quantize's fully_quantize parameter.
Raises:
ValueError: when there are duplicate keys
"""
self.layer_debug_metrics = layer_debug_metrics
self.model_debug_metrics = model_debug_metrics
self.layer_direct_compare_metrics = layer_direct_compare_metrics
keys = []
for metrics in [
layer_debug_metrics, model_debug_metrics, layer_direct_compare_metrics
]:
if metrics is not None:
keys.extend(metrics.keys())
if len(keys) != len(set(keys)):
raise ValueError('Provided metrics have duplicate keys.')
self.denylisted_ops = denylisted_ops
self.denylisted_nodes = denylisted_nodes
self.fully_quantize = fully_quantize
@tf_export.tf_export('lite.experimental.QuantizationDebugger')
class QuantizationDebugger:
"""Debugger for Quantized TensorFlow Lite debug mode models.
This can run the TensorFlow Lite converted models equipped with debug ops and
collect debug information. This debugger calculates statistics from
user-defined post-processing functions as well as default ones.
"""
def __init__(self,
quant_debug_model_path: Optional[str] = None,
quant_debug_model_content: Optional[bytes] = None,
float_model_path: Optional[str] = None,
float_model_content: Optional[bytes] = None,
debug_dataset: Optional[Callable[
[], Iterable[Sequence[np.ndarray]]]] = None,
debug_options: Optional[QuantizationDebugOptions] = None,
converter: Optional[TFLiteConverter] = None) -> None:
"""Runs the TFLite debugging model with given debug options.
Args:
quant_debug_model_path: Path to the quantized debug TFLite model file.
quant_debug_model_content: Content of the quantized debug TFLite model.
float_model_path: Path to float TFLite model file.
float_model_content: Content of the float TFLite model.
debug_dataset: a factory function that returns dataset generator which is
used to generate input samples (list of np.ndarray) for the model. The
generated elements must have same types and shape as inputs to the
model.
debug_options: Debug options to debug the given model.
converter: Optional, use converter instead of quantized model.
Raises:
ValueError: If the debugger was unable to be created.
Attributes:
layer_statistics: results of error metrics for each NumericVerify op
results. in {layer_name: {metric_name: metric}} format.
model_statistics: results of error metrics for difference between float
and quantized models. in {metric_name: metric} format.
"""
self._data_gen = debug_dataset
self._debug_options = debug_options or QuantizationDebugOptions()
self.converter = None
self.calibrated_model = None
self.float_model = None
self._float_interpreter = None
if converter is not None:
if self._debug_options.model_debug_metrics:
old_optimizations = converter.optimizations
self.converter = self._set_converter_options_for_float(converter)
self.float_model = self.converter.convert()
converter.optimizations = old_optimizations
self.converter = self._set_converter_options_for_calibration(converter)
self.calibrated_model = self.converter.convert()
# Converter should be already set up with all options
self._init_from_converter(
self._debug_options,
self.converter,
self.calibrated_model,
float_model=self.float_model)
else:
self._quant_interpreter = _interpreter.Interpreter(
quant_debug_model_path,
quant_debug_model_content,
experimental_preserve_all_tensors=(
self._debug_options.layer_direct_compare_metrics is not None))
if self._debug_options.model_debug_metrics:
self._float_interpreter = _interpreter.Interpreter(
float_model_path, float_model_content)
self._initialize_stats()
@property
def options(self) -> QuantizationDebugOptions:
return self._debug_options
@options.setter
def options(self, options: QuantizationDebugOptions) -> None:
self._debug_options = options
if not self.converter or not self.calibrated_model:
return
self._init_from_converter(
self._debug_options,
self.converter,
self.calibrated_model,
float_model=self.float_model)
self._initialize_stats()
def _initialize_stats(self):
"""Helper function initializes stats."""
# TODO(b/177749613) : Fix the dependency on tf.lite._get_ops_details()
# Following code is needed to get op's name from the output tensor index,
# since NumericVerify op only provides its quantized input tensor index.
self._defining_op = dict()
for op_info in self._quant_interpreter._get_ops_details(): # pylint: disable=protected-access
self._defining_op.update(
{tensor_idx: op_info['index'] for tensor_idx in op_info['outputs']})
self._numeric_verify_tensor_details = None
self._numeric_verify_op_details = None
if not self._get_numeric_verify_tensor_details():
raise ValueError('Please check if the quantized model is in debug mode')
self._layer_debug_metrics = _DEFAULT_LAYER_DEBUG_METRICS.copy()
if self._debug_options.layer_debug_metrics:
self._layer_debug_metrics.update(self._debug_options.layer_debug_metrics)
self.layer_statistics = None
self.model_statistics = None
self._metrics = metrics_stub.TFLiteMetrics()
self._metrics.increase_counter_debugger_creation()
def _get_quantized_model(self, is_debug: bool) -> bytes:
if not self.converter:
raise ValueError('No converter found, use this function with the '
'converter option in the constructor.')
return convert.mlir_quantize(
self.calibrated_model,
disable_per_channel=self.converter._experimental_disable_per_channel, # pylint: disable=protected-access
fully_quantize=self._debug_options.fully_quantize,
enable_numeric_verify=is_debug,
denylisted_ops=self._debug_options.denylisted_ops,
denylisted_nodes=self._debug_options.denylisted_nodes)
def get_nondebug_quantized_model(self) -> bytes:
"""Returns a non-instrumented quantized model.
Convert the quantized model with the initialized converter and
return bytes for nondebug model. The model will not be instrumented with
numeric verification operations.
Returns:
Model bytes corresponding to the model.
Raises:
ValueError: if converter is not passed to the debugger.
"""
return self._get_quantized_model(is_debug=False)
def get_debug_quantized_model(self) -> bytes:
"""Returns an instrumented quantized model.
Convert the quantized model with the initialized converter and
return bytes for model. The model will be instrumented with numeric
verification operations and should only be used for debugging.
Returns:
Model bytes corresponding to the model.
Raises:
ValueError: if converter is not passed to the debugger.
"""
return self._get_quantized_model(is_debug=True)
def _init_from_converter(self,
options: QuantizationDebugOptions,
converter: TFLiteConverter,
calibrated_model: Optional[bytes] = None,
float_model: Optional[bytes] = None) -> None:
"""Convert the model and apply options.
Converts the quantized model and initializes a quantized model interpreter
with the quantized model. Returns a float model interpreter if float model
is provided.
Args:
options: a QuantizationDebugOptions object.
converter: an initialized tf.lite.TFLiteConverter.
calibrated_model: Calibrated model bytes.
float_model: Float model bytes.
"""
self.quant_model = convert.mlir_quantize(
calibrated_model,
disable_per_channel=converter._experimental_disable_per_channel, # pylint: disable=protected-access
fully_quantize=options.fully_quantize,
enable_numeric_verify=True,
denylisted_ops=options.denylisted_ops,
denylisted_nodes=options.denylisted_nodes)
self._quant_interpreter = _interpreter.Interpreter(
model_content=self.quant_model)
self._float_interpreter = None
if float_model is not None:
self._float_interpreter = _interpreter.Interpreter(
model_content=float_model)
def _set_converter_options_for_float(
self, converter: TFLiteConverter) -> TFLiteConverter:
"""Verify converter options and set required experimental options."""
if converter.optimizations:
converter.optimizations = []
return converter
def _set_converter_options_for_calibration(
self, converter: TFLiteConverter) -> TFLiteConverter:
"""Verify converter options and set required experimental options."""
if not converter.optimizations:
raise ValueError(
'converter object must set optimizations to lite.Optimize.DEFAULT')
if not converter.representative_dataset:
raise ValueError('converter object must set representative_dataset')
converter.experimental_mlir_quantizer = True
converter._experimental_calibrate_only = True # pylint: disable=protected-access
return converter
def run(self) -> None:
"""Runs models and gets metrics."""
self.layer_statistics = self._collect_layer_statistics()
if self._debug_options.model_debug_metrics:
self.model_statistics = self._collect_model_statistics()
def _collect_layer_statistics(self) -> Dict[str, Dict[str, float]]:
"""Collects layer statistics by applying layer debug metrics.
For all data from the given RepresentativeDataset, collect statistics per
example by getting the NumericVerify op results in _quant_interpreter
and calculating layer debug metrics on the results.
Returns:
aggregated per-layer statistics of NumericVerify results.
{layer_name: {metric_name: metric}}
"""
layer_statistics = collections.defaultdict(
lambda: collections.defaultdict(list))
initialize = True
for tensor_data in self._data_gen():
self._set_input_tensors(self._quant_interpreter, tensor_data, initialize)
initialize = False
# Run the model.
self._quant_interpreter.invoke()
# Collect the statistics of this invoke result.
for tensor_detail in self._get_numeric_verify_tensor_details():
tensor_name = tensor_detail['name'] # pytype: disable=unsupported-operands # dynamic-method-lookup
diffs = self._quant_interpreter.get_tensor(tensor_detail['index']) # pytype: disable=unsupported-operands # dynamic-method-lookup
for metric_name, metric_fn in self._layer_debug_metrics.items():
layer_statistics[tensor_name][metric_name].append(metric_fn(diffs))
if self._debug_options.layer_direct_compare_metrics is not None:
for tensor_detail in self._get_numeric_verify_tensor_details():
tensor_name = tensor_detail['name'] # pytype: disable=unsupported-operands # dynamic-method-lookup
op_idx = self._defining_op[tensor_detail['index']] # pytype: disable=unsupported-operands # dynamic-method-lookup
op_detail = self._quant_interpreter._get_op_details(op_idx) # pylint: disable=protected-access
q_idx, f_idx = op_detail['inputs']
quant_input_detail = self._quant_interpreter._get_tensor_details( # pylint: disable=protected-access
q_idx, subgraph_index=0)
for (metric_name, metric_fn
) in self._debug_options.layer_direct_compare_metrics.items():
layer_statistics[tensor_name][metric_name].append(
metric_fn(
self._quant_interpreter.get_tensor(f_idx),
self._quant_interpreter.get_tensor(q_idx),
quant_input_detail['quantization_parameters']['scales'][0],
quant_input_detail['quantization_parameters']['zero_points']
[0]))
# Calculate final aggregated metrics for each layer.
for metrics in layer_statistics.values():
for metric_name in metrics:
metrics[metric_name] = np.nanmean(metrics[metric_name])
return layer_statistics
def _collect_model_statistics(self) -> Dict[str, float]:
"""Collects model output metrics.
For all data from the given RepresentativeDataset, collect all model output
results from float model & quantized debug model, and calculate metrics
by using model output functions. As a result, self.model_results is filled,
where self.model_results[model_output_function_name] = `aggregated model
output function value` (a scalar).
Returns:
aggregated per-model output discrepancy metrics.
{metric_name: aggregated_metric}
"""
model_statistics = collections.defaultdict(list)
initialize = True
for tensor_data in self._data_gen():
# Run quantized debug model and collect output results.
self._set_input_tensors(self._quant_interpreter, tensor_data, initialize)
self._quant_interpreter.invoke()
quant_tensor_data = self._get_output_tensors(self._quant_interpreter)
# Run float model if it's initialized.
float_tensor_data = []
if self._float_interpreter:
self._set_input_tensors(
self._float_interpreter, tensor_data, initialize)
self._float_interpreter.invoke()
float_tensor_data = self._get_output_tensors(self._float_interpreter)
initialize = False
# Calculate the metrics.
for (metric_name,
metric_fn) in self._debug_options.model_debug_metrics.items():
model_statistics[metric_name].append(
metric_fn(float_tensor_data, quant_tensor_data))
# Calculate final aggregated metrics for each outputs.
return {
metric_name: np.mean(metric)
for metric_name, metric in model_statistics.items()
}
def _set_input_tensors(self, interpreter: _interpreter.Interpreter,
tensor_data: Sequence[np.ndarray],
initialize: bool) -> None:
"""Sets input tensors into TFLite model Interpreter.
Args:
interpreter: a tf.lite.Interpreter object with allocated tensors.
tensor_data: a list of Numpy array data.
initialize: set to true when input is first set for the interpreter, to
set input shapes and allocate tensors.
Raises:
ValueError: when inputs can't be set, or size of provided inputs does not
match size of model inputs.
"""
input_details = interpreter.get_input_details()
if len(input_details) != len(tensor_data):
raise ValueError(
'Number of inputs provided ({}) does not match number of inputs to '
'the model ({})'.format(len(tensor_data), len(input_details)))
if initialize:
for input_detail, tensor in zip(input_details, tensor_data):
interpreter.resize_tensor_input(input_detail['index'], tensor.shape)
interpreter.allocate_tensors()
for input_detail, tensor in zip(input_details, tensor_data):
if tensor.dtype == np.float32 and input_detail['dtype'] == np.int8:
quant_params = _get_quant_params(input_detail)
if quant_params:
scale, zero_point = quant_params
tensor = np.round((tensor / scale) + zero_point).astype(np.int8)
interpreter.set_tensor(input_detail['index'], tensor)
def _get_output_tensors(
self, interpreter: _interpreter.Interpreter) -> List[np.ndarray]:
"""Returns output tensors of given TFLite model Interpreter.
Args:
interpreter: a tf.lite.Interpreter object with allocated tensors.
Returns:
a list of numpy arrays representing output tensor results.
"""
outputs = []
for output_detail in interpreter.get_output_details():
tensor = interpreter.get_tensor(output_detail['index'])
if output_detail['dtype'] == np.int8:
quant_params = _get_quant_params(output_detail)
if quant_params:
scale, zero_point = quant_params
tensor = ((tensor.astype(np.float32) - zero_point) * scale).astype(
np.float32)
outputs.append(tensor)
return outputs
def _get_numeric_verify_tensor_details(self) -> List[str]:
"""Returns all names of all tensors from NumericVerify op."""
# pylint: disable=protected-access
if not self._numeric_verify_tensor_details:
self._numeric_verify_tensor_details = []
self._numeric_verify_op_details = {}
for op_info in self._quant_interpreter._get_ops_details():
if op_info['op_name'] == _NUMERIC_VERIFY_OP_NAME:
self._numeric_verify_tensor_details.append(
self._quant_interpreter._get_tensor_details(
op_info['outputs'][0], subgraph_index=0))
tensor_name = self._numeric_verify_tensor_details[-1]['name']
self._numeric_verify_op_details[tensor_name] = op_info
# pylint: enable=protected-access
return self._numeric_verify_tensor_details
def _get_operand_name_and_index(self,
numeric_verify_name: str) -> Tuple[str, int]:
"""Gets the index and name of NumericVerify Op's quantized input tensor.
Args:
numeric_verify_name: name of the NumericVerify op's output tensor. It has
format of `NumericVerify/{quantized_tensor_name}:{quantized_tensor_idx}`
Returns:
Tuple of (tensor_name, tensor_idx) for quantized op's output tensor.
"""
tensor_name, tensor_idx = numeric_verify_name.rsplit(':', 1)
float_tensor_name = tensor_name[len(_NUMERIC_VERIFY_OP_NAME) + 1:]
if re.match(r'\d', float_tensor_name[-1]):
float_tensor_name = float_tensor_name[:-1]
return (float_tensor_name, int(tensor_idx))
def layer_statistics_dump(self, file: IO[str]) -> None:
"""Dumps layer statistics into file, in csv format.
Args:
file: file, or file-like object to write.
"""
# order of `fields` is the order of fields in csv.
fields = ['op_name', 'tensor_idx'] + list(self._layer_debug_metrics.keys())
if self._debug_options.layer_direct_compare_metrics is not None:
fields += list(self._debug_options.layer_direct_compare_metrics.keys())
fields += ['scale', 'zero_point', 'tensor_name']
writer = csv.DictWriter(file, fields)
writer.writeheader()
if self.layer_statistics:
for name, metrics in self.layer_statistics.items():
data = metrics.copy()
(data['tensor_name'], _) = self._get_operand_name_and_index(name)
data['tensor_idx'] = self._numeric_verify_op_details[name]['inputs'][0]
data['op_name'] = self._quant_interpreter._get_op_details( # pylint: disable=protected-access
self._defining_op[data['tensor_idx']])['op_name']
details = self._quant_interpreter._get_tensor_details( # pylint: disable=protected-access
data['tensor_idx'], subgraph_index=0)
data['scale'], data['zero_point'] = (
details['quantization_parameters']['scales'][0],
details['quantization_parameters']['zero_points'][0])
writer.writerow(data)
@@ -0,0 +1,35 @@
from sys import modules
from types import ModuleType
def __update_globals(new_import_path, pywrap_m):
all_names = pywrap_m.__all__ if hasattr(pywrap_m, '__all__') else dir(
pywrap_m)
modules[new_import_path] = pywrap_m
for name in all_names:
sub_pywrap = getattr(pywrap_m, name)
if isinstance(sub_pywrap, ModuleType):
sub_name = sub_pywrap.__name__[len(pywrap_m.__name__):]
__update_globals(new_import_path + sub_name, sub_pywrap)
def __try_import():
imports_paths = ["litert.python._pywrap_modify_model_interface", "third_party.tensorflow.lite.python._pywrap_modify_model_interface", "tensorflow._pywrap_modify_model_interface", "tensorflow.python._pywrap_modify_model_interface"] # template_val
exceptions = []
last_exception = None
for import_path in imports_paths:
try:
pywrap_m = __import__(import_path, fromlist=["*"])
__update_globals(__name__, pywrap_m)
return
except ImportError as e:
exceptions.append(str(e))
last_exception = e
pass
raise RuntimeError(f"""
Could not import original test/binary location, import paths tried: {imports_paths}.
Previous exceptions: {exceptions}""", last_exception)
__try_import()
@@ -0,0 +1,35 @@
from sys import modules
from types import ModuleType
def __update_globals(new_import_path, pywrap_m):
all_names = pywrap_m.__all__ if hasattr(pywrap_m, '__all__') else dir(
pywrap_m)
modules[new_import_path] = pywrap_m
for name in all_names:
sub_pywrap = getattr(pywrap_m, name)
if isinstance(sub_pywrap, ModuleType):
sub_name = sub_pywrap.__name__[len(pywrap_m.__name__):]
__update_globals(new_import_path + sub_name, sub_pywrap)
def __try_import():
imports_paths = ["litert.python.format_converter_wrapper_pybind11", "third_party.tensorflow.lite.python.format_converter_wrapper_pybind11", "tensorflow.format_converter_wrapper_pybind11", "tensorflow.python.format_converter_wrapper_pybind11"] # template_val
exceptions = []
last_exception = None
for import_path in imports_paths:
try:
pywrap_m = __import__(import_path, fromlist=["*"])
__update_globals(__name__, pywrap_m)
return
except ImportError as e:
exceptions.append(str(e))
last_exception = e
pass
raise RuntimeError(f"""
Could not import original test/binary location, import paths tried: {imports_paths}.
Previous exceptions: {exceptions}""", last_exception)
__try_import()
@@ -0,0 +1,582 @@
#!/usr/bin/env python
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""This tool creates an html visualization of a TensorFlow Lite graph.
Example usage:
python visualize.py foo.tflite foo.html
"""
import json
import os
import re
import sys
import numpy as np
# pylint: disable=g-import-not-at-top
if not os.path.splitext(__file__)[0].endswith(
os.path.join("tflite_runtime", "visualize")):
# This file is part of tensorflow package.
from tensorflow.lite.python import schema_py_generated as schema_fb
else:
# This file is part of tflite_runtime package.
from tflite_runtime import schema_py_generated as schema_fb
# A CSS description for making the visualizer
_CSS = """
<html>
<head>
<style>
body {font-family: sans-serif; background-color: #fa0;}
table {background-color: #eca;}
th {background-color: black; color: white;}
/* Constrain table cells to a max size and make them scrollable. */
.data-table td {
max-width: 900px;
}
.data-table .cell-content {
max-height: 200px;
overflow: auto;
white-space: pre-wrap;
word-break: break-all;
}
h1 {
background-color: ffaa00;
padding:5px;
color: black;
}
svg {
margin: 10px;
border: 2px;
border-style: solid;
border-color: black;
background: white;
}
div {
border-radius: 5px;
background-color: #fec;
padding:5px;
margin:5px;
}
.tooltip {color: blue;}
.tooltip .tooltipcontent {
visibility: hidden;
color: black;
background-color: yellow;
padding: 5px;
border-radius: 4px;
position: absolute;
z-index: 1;
}
.tooltip:hover .tooltipcontent {
visibility: visible;
}
.edges line {
stroke: #333;
}
text {
font-weight: bold;
}
.nodes text {
color: black;
pointer-events: none;
font-family: sans-serif;
font-size: 11px;
}
</style>
<script src="https://d3js.org/d3.v4.min.js"></script>
</head>
<body>
"""
_D3_HTML_TEMPLATE = """
<script>
function buildGraph() {
// Build graph data
var graph = %s;
var svg = d3.select("#subgraph%d")
var width = svg.attr("width");
var height = svg.attr("height");
// Make the graph scrollable.
svg = svg.call(d3.zoom().on("zoom", function() {
svg.attr("transform", d3.event.transform);
})).append("g");
var color = d3.scaleOrdinal(d3.schemeDark2);
var simulation = d3.forceSimulation()
.force("link", d3.forceLink().id(function(d) {return d.id;}))
.force("charge", d3.forceManyBody())
.force("center", d3.forceCenter(0.5 * width, 0.5 * height));
var edge = svg.append("g").attr("class", "edges").selectAll("line")
.data(graph.edges).enter().append("path").attr("stroke","black").attr("fill","none")
// Make the node group
var node = svg.selectAll(".nodes")
.data(graph.nodes)
.enter().append("g")
.attr("x", function(d){return d.x})
.attr("y", function(d){return d.y})
.attr("transform", function(d) {
return "translate( " + d.x + ", " + d.y + ")"
})
.attr("class", "nodes")
.call(d3.drag()
.on("start", function(d) {
if(!d3.event.active) simulation.alphaTarget(1.0).restart();
d.fx = d.x;d.fy = d.y;
})
.on("drag", function(d) {
d.fx = d3.event.x; d.fy = d3.event.y;
})
.on("end", function(d) {
if (!d3.event.active) simulation.alphaTarget(0);
d.fx = d.fy = null;
}));
// Within the group, draw a box for the node position and text
// on the side.
var node_width = 150;
var node_height = 30;
node.append("rect")
.attr("r", "5px")
.attr("width", node_width)
.attr("height", node_height)
.attr("rx", function(d) { return d.group == 1 ? 1 : 10; })
.attr("stroke", "#000000")
.attr("fill", function(d) { return d.group == 1 ? "#dddddd" : "#000000"; })
node.append("text")
.text(function(d) { return d.name; })
.attr("x", 5)
.attr("y", 20)
.attr("fill", function(d) { return d.group == 1 ? "#000000" : "#eeeeee"; })
// Setup force parameters and update position callback
var node = svg.selectAll(".nodes")
.data(graph.nodes);
// Bind the links
var name_to_g = {}
node.each(function(data, index, nodes) {
console.log(data.id)
name_to_g[data.id] = this;
});
function proc(w, t) {
return parseInt(w.getAttribute(t));
}
edge.attr("d", function(d) {
function lerp(t, a, b) {
return (1.0-t) * a + t * b;
}
var x1 = proc(name_to_g[d.source],"x") + node_width /2;
var y1 = proc(name_to_g[d.source],"y") + node_height;
var x2 = proc(name_to_g[d.target],"x") + node_width /2;
var y2 = proc(name_to_g[d.target],"y");
var s = "M " + x1 + " " + y1
+ " C " + x1 + " " + lerp(.5, y1, y2)
+ " " + x2 + " " + lerp(.5, y1, y2)
+ " " + x2 + " " + y2
return s;
});
}
buildGraph()
</script>
"""
def TensorTypeToName(tensor_type):
"""Converts a numerical enum to a readable tensor type."""
for name, value in schema_fb.TensorType.__dict__.items():
if value == tensor_type:
return name
return None
def BuiltinCodeToName(code):
"""Converts a builtin op code enum to a readable name."""
for name, value in schema_fb.BuiltinOperator.__dict__.items():
if value == code:
return name
return None
def NameListToString(name_list):
"""Converts a list of integers to the equivalent ASCII string."""
if isinstance(name_list, str):
return name_list
else:
result = ""
if name_list is not None:
for val in name_list:
result = result + chr(int(val))
return result
class OpCodeMapper:
"""Maps an opcode index to an op name."""
def __init__(self, data):
self.code_to_name = {}
for idx, d in enumerate(data["operator_codes"]):
self.code_to_name[idx] = BuiltinCodeToName(d["builtin_code"])
if self.code_to_name[idx] == "CUSTOM":
self.code_to_name[idx] = NameListToString(d["custom_code"])
def __call__(self, x):
if x not in self.code_to_name:
s = "<UNKNOWN>"
else:
s = self.code_to_name[x]
return "%s (%d)" % (s, x)
class DataSizeMapper:
"""For buffers, report the number of bytes."""
def __call__(self, x):
if x is not None:
return "%d bytes" % len(x)
else:
return "--"
class TensorMapper:
"""Maps a list of tensor indices to a tooltip hoverable indicator of more."""
def __init__(self, subgraph_data):
self.data = subgraph_data
def __call__(self, x):
html = ""
if x is None:
return html
html += "<span class='tooltip'><span class='tooltipcontent'>"
for i in x:
tensor = self.data["tensors"][i]
html += str(i) + " "
html += NameListToString(tensor["name"]) + " "
html += TensorTypeToName(tensor["type"]) + " "
html += (repr(tensor["shape"]) if "shape" in tensor else "[]")
html += (repr(tensor["shape_signature"])
if "shape_signature" in tensor else "[]") + "<br>"
html += "</span>"
html += repr(x)
html += "</span>"
return html
def QuantizationMapper(q):
"""Pretty-print the quantization dictionary, truncating large arrays."""
if not q:
return ""
items_str = []
for key, value in q.items():
key_str = repr(key)
# In TFLite, quantization arrays can be large.
if isinstance(value, list) and len(value) > 20:
head = value[:10]
tail = value[-10:]
value_str = (f"[{', '.join(map(repr, head))}, ..., "
f"{', '.join(map(repr, tail))}]")
else:
value_str = repr(value)
items_str.append(f"{key_str}: {value_str}")
return f"{{{', '.join(items_str)}}}"
def GenerateGraph(subgraph_idx, g, opcode_mapper):
"""Produces the HTML required to have a d3 visualization of the dag."""
def TensorName(idx):
return "t%d" % idx
def OpName(idx):
return "o%d" % idx
edges = []
nodes = []
first = {}
second = {}
pixel_mult = 200 # TODO(aselle): multiplier for initial placement
width_mult = 170 # TODO(aselle): multiplier for initial placement
for op_index, op in enumerate(g["operators"] or []):
if op["inputs"] is not None:
for tensor_input_position, tensor_index in enumerate(op["inputs"]):
if tensor_index not in first:
first[tensor_index] = ((op_index - 0.5 + 1) * pixel_mult,
(tensor_input_position + 1) * width_mult)
edges.append({
"source": TensorName(tensor_index),
"target": OpName(op_index)
})
if op["outputs"] is not None:
for tensor_output_position, tensor_index in enumerate(op["outputs"]):
if tensor_index not in second:
second[tensor_index] = ((op_index + 0.5 + 1) * pixel_mult,
(tensor_output_position + 1) * width_mult)
edges.append({
"target": TensorName(tensor_index),
"source": OpName(op_index)
})
nodes.append({
"id": OpName(op_index),
"name": opcode_mapper(op["opcode_index"]),
"group": 2,
"x": pixel_mult,
"y": (op_index + 1) * pixel_mult
})
for tensor_index, tensor in enumerate(g["tensors"]):
initial_y = (
first[tensor_index] if tensor_index in first else
second[tensor_index] if tensor_index in second else (0, 0))
nodes.append({
"id": TensorName(tensor_index),
"name": "%r (%d)" % (getattr(tensor, "shape", []), tensor_index),
"group": 1,
"x": initial_y[1],
"y": initial_y[0]
})
graph_str = json.dumps({"nodes": nodes, "edges": edges})
html = _D3_HTML_TEMPLATE % (graph_str, subgraph_idx)
return html
def GenerateTableHtml(items, keys_to_print, display_index=True):
"""Given a list of object values and keys to print, make an HTML table.
Args:
items: Items to print an array of dicts.
keys_to_print: (key, display_fn). `key` is a key in the object. i.e.
items[0][key] should exist. display_fn is the mapping function on display.
i.e. the displayed html cell will have the string returned by
`mapping_fn(items[0][key])`.
display_index: add a column which is the index of each row in `items`.
Returns:
An html table.
"""
html = ""
# Print the list of items
html += "<table class='data-table'>\n"
html += "<tr>\n"
if display_index:
html += "<th>index</th>"
for h, mapper in keys_to_print:
html += "<th>%s</th>" % h
html += "</tr>\n"
for idx, tensor in enumerate(items):
html += "<tr>\n"
if display_index:
html += "<td>%d</td>" % idx
# print tensor.keys()
for h, mapper in keys_to_print:
val = tensor[h] if h in tensor else None
val = val if mapper is None else mapper(val)
html += "<td><div class='cell-content'>%s</div></td>\n" % val
html += "</tr>\n"
html += "</table>\n"
return html
def CamelCaseToSnakeCase(camel_case_input):
"""Converts an identifier in CamelCase to snake_case."""
s1 = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", camel_case_input)
return re.sub("([a-z0-9])([A-Z])", r"\1_\2", s1).lower()
def FlatbufferToDict(fb, preserve_as_numpy):
"""Converts a hierarchy of FB objects into a nested dict.
We avoid transforming big parts of the flat buffer into python arrays. This
speeds conversion from ten minutes to a few seconds on big graphs.
Args:
fb: a flat buffer structure. (i.e. ModelT)
preserve_as_numpy: true if all downstream np.arrays should be preserved.
false if all downstream np.array should become python arrays
Returns:
A dictionary representing the flatbuffer rather than a flatbuffer object.
"""
if isinstance(fb, int) or isinstance(fb, float) or isinstance(fb, str):
return fb
elif hasattr(fb, "__dict__"):
result = {}
for attribute_name in dir(fb):
attribute = fb.__getattribute__(attribute_name)
if not callable(attribute) and attribute_name[0] != "_":
snake_name = CamelCaseToSnakeCase(attribute_name)
preserve = True if attribute_name == "buffers" else preserve_as_numpy
result[snake_name] = FlatbufferToDict(attribute, preserve)
return result
elif isinstance(fb, np.ndarray):
return fb if preserve_as_numpy else fb.tolist()
elif hasattr(fb, "__len__"):
return [FlatbufferToDict(entry, preserve_as_numpy) for entry in fb]
else:
return fb
def CreateDictFromFlatbuffer(buffer_data):
model_obj = schema_fb.Model.GetRootAsModel(buffer_data, 0)
model = schema_fb.ModelT.InitFromObj(model_obj)
return FlatbufferToDict(model, preserve_as_numpy=False)
def create_html(tflite_input, input_is_filepath=True): # pylint: disable=invalid-name
"""Returns html description with the given tflite model.
Args:
tflite_input: TFLite flatbuffer model path or model object.
input_is_filepath: Tells if tflite_input is a model path or a model object.
Returns:
Dump of the given tflite model in HTML format.
Raises:
RuntimeError: If the input is not valid.
"""
# Convert the model into a JSON flatbuffer using flatc (build if doesn't
# exist.
if input_is_filepath:
if not os.path.exists(tflite_input):
raise RuntimeError("Invalid filename %r" % tflite_input)
if tflite_input.endswith(".tflite") or tflite_input.endswith(".bin"):
with open(tflite_input, "rb") as file_handle:
file_data = bytearray(file_handle.read())
data = CreateDictFromFlatbuffer(file_data)
elif tflite_input.endswith(".json"):
data = json.load(open(tflite_input))
else:
raise RuntimeError("Input file was not .tflite or .json")
else:
data = CreateDictFromFlatbuffer(tflite_input)
html = ""
html += _CSS
html += "<h1>TensorFlow Lite Model</h2>"
data["filename"] = tflite_input if input_is_filepath else (
"Null (used model object)") # Avoid special case
toplevel_stuff = [("filename", None), ("version", None),
("description", None)]
html += "<table class='data-table'>\n"
for key, mapping in toplevel_stuff:
if not mapping:
mapping = lambda x: x
val = mapping(data.get(key))
html += ("<tr><th>%s</th><td><div class='cell-content'>%s</div></td></tr>\n"
% (key, val))
html += "</table>\n"
# Spec on what keys to display
buffer_keys_to_display = [("data", DataSizeMapper())]
operator_keys_to_display = [("builtin_code", BuiltinCodeToName),
("custom_code", NameListToString),
("version", None)]
# Update builtin code fields.
for d in data["operator_codes"]:
d["builtin_code"] = max(d["builtin_code"], d["deprecated_builtin_code"])
for subgraph_idx, g in enumerate(data["subgraphs"]):
# Subgraph local specs on what to display
html += "<div class='subgraph'>"
tensor_mapper = TensorMapper(g)
opcode_mapper = OpCodeMapper(data)
op_keys_to_display = [("inputs", tensor_mapper), ("outputs", tensor_mapper),
("builtin_options", None),
("opcode_index", opcode_mapper)]
tensor_keys_to_display = [("name", NameListToString),
("type", TensorTypeToName), ("shape", None),
("shape_signature", None), ("buffer", None),
("quantization", QuantizationMapper)]
html += "<h2>Subgraph %d</h2>\n" % subgraph_idx
# Inputs and outputs.
html += "<h3>Inputs/Outputs</h3>\n"
html += GenerateTableHtml([{
"inputs": g["inputs"],
"outputs": g["outputs"]
}], [("inputs", tensor_mapper), ("outputs", tensor_mapper)],
display_index=False)
# Print the tensors.
html += "<h3>Tensors</h3>\n"
html += GenerateTableHtml(g["tensors"], tensor_keys_to_display)
# Print the ops.
if g["operators"]:
html += "<h3>Ops</h3>\n"
html += GenerateTableHtml(g["operators"], op_keys_to_display)
# Visual graph.
html += "<svg id='subgraph%d' width='1600' height='900'></svg>\n" % (
subgraph_idx,)
html += GenerateGraph(subgraph_idx, g, opcode_mapper)
html += "</div>"
# Buffers have no data, but maybe in the future they will
html += "<h2>Buffers</h2>\n"
html += GenerateTableHtml(data["buffers"], buffer_keys_to_display)
# Operator codes
html += "<h2>Operator Codes</h2>\n"
html += GenerateTableHtml(data["operator_codes"], operator_keys_to_display)
html += "</body></html>\n"
return html
def main(argv):
try:
tflite_input = argv[1]
html_output = argv[2]
except IndexError:
print("Usage: %s <input tflite> <output html>" % (argv[0]))
else:
html = create_html(tflite_input)
with open(html_output, "w") as output_file:
output_file.write(html)
if __name__ == "__main__":
main(sys.argv)