import ast from copy import copy import itertools import json import os import struct import numpy as np import pandas as pd from pandas.core.arrays.masked import BaseMaskedDtype from fastparquet.util import join_path from fastparquet import parquet_thrift, __version__, cencoding from fastparquet.api import ParquetFile, partitions, part_ids from fastparquet.compression import compress_data from fastparquet.converted_types import tobson from fastparquet.json import json_encoder from fastparquet.util import (default_open, default_mkdirs, check_column_names, created_by, get_column_metadata, norm_col_name, path_string, reset_row_idx, get_fs, update_custom_metadata) from fastparquet.speedups import array_encode_utf8, pack_byte_array from fastparquet.cencoding import NumpyIO, ThriftObject, from_buffer from decimal import Decimal MARKER = b'PAR1' ROW_GROUP_SIZE = 50_000_000 NaT = np.timedelta64(None).tobytes() # require numpy version >= 1.7 nat = np.datetime64('NaT').view('int64') typemap = { # primitive type, converted type, bit width 'boolean': (parquet_thrift.Type.BOOLEAN, None, 1), 'Int32': (parquet_thrift.Type.INT32, None, 32), 'Int64': (parquet_thrift.Type.INT64, None, 64), 'Int8': (parquet_thrift.Type.INT32, parquet_thrift.ConvertedType.INT_8, 8), 'Int16': (parquet_thrift.Type.INT32, parquet_thrift.ConvertedType.INT_16, 16), 'UInt8': (parquet_thrift.Type.INT32, parquet_thrift.ConvertedType.UINT_8, 8), 'UInt16': (parquet_thrift.Type.INT32, parquet_thrift.ConvertedType.UINT_16, 16), 'UInt32': (parquet_thrift.Type.INT32, parquet_thrift.ConvertedType.UINT_32, 32), 'UInt64': (parquet_thrift.Type.INT64, parquet_thrift.ConvertedType.UINT_64, 64), 'bool': (parquet_thrift.Type.BOOLEAN, None, 1), 'int32': (parquet_thrift.Type.INT32, None, 32), 'int64': (parquet_thrift.Type.INT64, None, 64), 'int8': (parquet_thrift.Type.INT32, parquet_thrift.ConvertedType.INT_8, 8), 'int16': (parquet_thrift.Type.INT32, parquet_thrift.ConvertedType.INT_16, 16), 'uint8': (parquet_thrift.Type.INT32, parquet_thrift.ConvertedType.UINT_8, 8), 'uint16': (parquet_thrift.Type.INT32, parquet_thrift.ConvertedType.UINT_16, 16), 'uint32': (parquet_thrift.Type.INT32, parquet_thrift.ConvertedType.UINT_32, 32), 'uint64': (parquet_thrift.Type.INT64, parquet_thrift.ConvertedType.UINT_64, 64), 'float32': (parquet_thrift.Type.FLOAT, None, 32), 'float64': (parquet_thrift.Type.DOUBLE, None, 64), 'float16': (parquet_thrift.Type.FLOAT, None, 16), 'Float32': (parquet_thrift.Type.FLOAT, None, 32), 'Float64': (parquet_thrift.Type.DOUBLE, None, 64), 'Float16': (parquet_thrift.Type.FLOAT, None, 16), } revmap = {parquet_thrift.Type.INT32: np.int32, parquet_thrift.Type.INT64: np.int64, parquet_thrift.Type.FLOAT: np.float32, parquet_thrift.Type.DOUBLE: np.float64} pdoptional_to_numpy_typemap = { pd.Int8Dtype(): np.int8, pd.Int16Dtype(): np.int16, pd.Int32Dtype(): np.int32, pd.Int64Dtype(): np.int64, pd.UInt8Dtype(): np.uint8, pd.UInt16Dtype(): np.uint16, pd.UInt32Dtype(): np.uint32, pd.UInt64Dtype(): np.uint64, pd.BooleanDtype(): bool } def find_type(data, fixed_text=None, object_encoding=None, times='int64', is_index:bool=None): """ Get appropriate typecodes for column dtype Data conversion do not happen here, see convert(). The user is expected to transform their data into the appropriate dtype before saving to parquet, we will not make any assumptions for them. Known types that cannot be represented (must be first converted another type or to raw binary): float128, complex Parameters ---------- data: pd.Series fixed_text: int or None For str and bytes, the fixed-string length to use. If None, object column will remain variable length. object_encoding: None or infer|bytes|utf8|json|bson|bool|int|int32|float How to encode object type into bytes. If None, bytes is assumed; if 'infer', type is guessed from 10 first non-null values. times: 'int64'|'int96' Normal integers or 12-byte encoding for timestamps. is_index: bool, optional Set `True` if column storing a row index, `False` otherwise. Required if column name is a tuple (when dataframe managed has a column multi-index). In this case, with this flag set `True`, name of columns used to store a row index are reset from tuple to simple string. Returns ------- - a thrift schema element - a thrift typecode to be passed to the column chunk writer - converted data (None if convert is False) """ dtype = data.dtype logical_type = None if dtype.name in typemap: type, converted_type, width = typemap[dtype.name] elif "S" in str(dtype)[:2] or "U" in str(dtype)[:2]: type, converted_type, width = (parquet_thrift.Type.FIXED_LEN_BYTE_ARRAY, None, dtype.itemsize) elif dtype == "O": if object_encoding == 'infer': object_encoding = infer_object_encoding(data) if object_encoding == 'utf8': type, converted_type, width = (parquet_thrift.Type.BYTE_ARRAY, parquet_thrift.ConvertedType.UTF8, None) elif object_encoding in ['bytes', None]: type, converted_type, width = (parquet_thrift.Type.BYTE_ARRAY, None, None) elif object_encoding == 'json': type, converted_type, width = (parquet_thrift.Type.BYTE_ARRAY, parquet_thrift.ConvertedType.JSON, None) elif object_encoding == 'bson': type, converted_type, width = (parquet_thrift.Type.BYTE_ARRAY, parquet_thrift.ConvertedType.BSON, None) elif object_encoding == 'bool': type, converted_type, width = (parquet_thrift.Type.BOOLEAN, None, 1) elif object_encoding == 'int': type, converted_type, width = (parquet_thrift.Type.INT64, None, 64) elif object_encoding == 'int32': type, converted_type, width = (parquet_thrift.Type.INT32, None, 32) elif object_encoding == 'float': type, converted_type, width = (parquet_thrift.Type.DOUBLE, None, 64) elif object_encoding == 'decimal': type, converted_type, width = (parquet_thrift.Type.DOUBLE, None, 64) else: raise ValueError('Object encoding (%s) not one of ' 'infer|utf8|bytes|json|bson|bool|int|int32|float|decimal' % object_encoding) if fixed_text: width = fixed_text type = parquet_thrift.Type.FIXED_LEN_BYTE_ARRAY elif dtype.kind == "M": if times == 'int64': # output will have the same resolution as original data, for resolution <= ms tz = getattr(dtype, "tz", None) is not None if "ns" in dtype.str: type = parquet_thrift.Type.INT64 converted_type = None logical_type = parquet_thrift.LogicalType( TIMESTAMP=parquet_thrift.TimestampType( isAdjustedToUTC=tz, unit=parquet_thrift.TimeUnit(NANOS=parquet_thrift.NanoSeconds()) ) ) width = None elif "us" in dtype.str: type, converted_type, width = ( parquet_thrift.Type.INT64, parquet_thrift.ConvertedType.TIMESTAMP_MICROS, None ) logical_type = ThriftObject.from_fields( "LogicalType", TIMESTAMP=ThriftObject.from_fields( "TimestampType", isAdjustedToUTC=tz, unit=ThriftObject.from_fields("TimeUnit", MICROS={}) ) ) else: type, converted_type, width = ( parquet_thrift.Type.INT64, parquet_thrift.ConvertedType.TIMESTAMP_MILLIS, None ) logical_type = ThriftObject.from_fields( "LogicalType", TIMESTAMP=ThriftObject.from_fields( "TimestampType", isAdjustedToUTC=tz, unit=ThriftObject.from_fields("TimeUnit", MILLIS={}) ) ) elif times == 'int96': type, converted_type, width = (parquet_thrift.Type.INT96, None, None) else: raise ValueError( "Parameter times must be [int64|int96], not %s" % times) # warning removed as irrelevant for most users # if hasattr(dtype, 'tz') and str(dtype.tz) != 'UTC': # warnings.warn( # 'Coercing datetimes to UTC before writing the parquet file, the timezone is stored in the metadata. ' # 'Reading back with fastparquet/pyarrow will restore the timezone properly.' # ) elif dtype.kind == "m": type, converted_type, width = (parquet_thrift.Type.INT64, parquet_thrift.ConvertedType.TIME_MICROS, None) elif "str" in str(dtype): type, converted_type, width = (parquet_thrift.Type.BYTE_ARRAY, parquet_thrift.ConvertedType.UTF8, None) else: raise ValueError("Don't know how to convert data type: %s" % dtype) se = parquet_thrift.SchemaElement( name=norm_col_name(data.name, is_index), type_length=width, converted_type=converted_type, type=type, repetition_type=parquet_thrift.FieldRepetitionType.REQUIRED, logicalType=logical_type, i32=True ) return se, type def convert(data, se): """Convert data according to the schema encoding""" dtype = data.dtype type = se.type converted_type = se.converted_type if dtype.name in typemap: if type in revmap: out = data.values.astype(revmap[type], copy=False) elif type == parquet_thrift.Type.BOOLEAN: # TODO: with our own bitpack writer, no need to copy for # the padding padded = np.pad(data.values, (0, 8 - (len(data) % 8)), 'constant', constant_values=(0, 0)) out = np.packbits(padded.reshape(-1, 8)[:, ::-1].ravel()) elif dtype.name in typemap: out = data.values elif "S" in str(dtype)[:2] or "U" in str(dtype)[:2]: out = data.values elif dtype == "O": # TODO: nullable types try: if converted_type == parquet_thrift.ConvertedType.UTF8: # getattr for new pandas StringArray # TODO: to bytes in one step out = array_encode_utf8(data) elif converted_type == parquet_thrift.ConvertedType.DECIMAL: out = data.values.astype(np.float64, copy=False) elif converted_type is None: if type in revmap: out = data.values.astype(revmap[type], copy=False) elif type == parquet_thrift.Type.BOOLEAN: # TODO: with our own bitpack writer, no need to copy for # the padding padded = np.pad(data.values, (0, 8 - (len(data) % 8)), 'constant', constant_values=(0, 0)) out = np.packbits(padded.reshape(-1, 8)[:, ::-1].ravel()) else: out = data.values elif converted_type == parquet_thrift.ConvertedType.JSON: encoder = json_encoder() # TODO: avoid list. np.fromiter can be used with numpy >= 1.23.0, # but older versions don't support object arrays. out = np.array([encoder(x) for x in data], dtype="O") elif converted_type == parquet_thrift.ConvertedType.BSON: out = data.map(tobson).values if type == parquet_thrift.Type.FIXED_LEN_BYTE_ARRAY: out = out.astype('S%i' % se.type_length) except Exception as e: ct = parquet_thrift.ConvertedType._VALUES_TO_NAMES[ converted_type] if converted_type is not None else None raise ValueError('Error converting column "%s" to bytes using ' 'encoding %s. Original error: ' '%s' % (data.name, ct, e)) elif "str" in str(dtype): try: if converted_type == parquet_thrift.ConvertedType.UTF8: # TODO: into bytes in one step out = array_encode_utf8(data) elif converted_type is None: out = data.values if type == parquet_thrift.Type.FIXED_LEN_BYTE_ARRAY: out = out.astype('S%i' % se.type_length) except Exception as e: # pragma: no cover ct = parquet_thrift.ConvertedType._VALUES_TO_NAMES[ converted_type] if converted_type is not None else None raise ValueError('Error converting column "%s" to bytes using ' 'encoding %s. Original error: ' '%s' % (data.name, ct, e)) elif converted_type == parquet_thrift.ConvertedType.TIME_MICROS: # TODO: shift inplace if data.dtype == "m8[ns]": out = np.empty(len(data), 'int64') time_shift(data.values.view('int64'), out) else: # assuming ms or us out = data.values elif type == parquet_thrift.Type.INT96 and dtype.kind == 'M': ns_per_day = (24 * 3600 * 1000000000) day = data.values.view('int64') // ns_per_day + 2440588 ns = (data.values.view('int64') % ns_per_day) # - ns_per_day // 2 out = np.empty(len(data), dtype=[('ns', 'i8'), ('day', 'i4')]) out['ns'] = ns out['day'] = day elif dtype.kind == "M": part = str(dtype).split("[")[1][:-1].split(",")[0] if converted_type: factor = time_factors[(converted_type, part)] else: unit = [k for k, v in se.logicalType.TIMESTAMP.unit._asdict().items() if v is not None][0] factor = time_factors[(unit, part)] out = data.values.view("int64") // factor else: raise ValueError("Don't know how to convert data type: %s" % dtype) return out time_factors = { ("NANOS", "ns"): 1, (parquet_thrift.ConvertedType.TIMESTAMP_MICROS, "us"): 1, (parquet_thrift.ConvertedType.TIMESTAMP_MICROS, "ns"): 1000, (parquet_thrift.ConvertedType.TIMESTAMP_MILLIS, "ms"): 1, (parquet_thrift.ConvertedType.TIMESTAMP_MILLIS, "s"): 1000, } def infer_object_encoding(data): """Guess object type from first 10 non-na values by iteration""" if data.empty: return "utf8" t = None s = 0 encs = { str: "utf8", bytes: "bytes", list: "json", dict: "json", bool: "bool", Decimal: "decimal", int: "int", float: "float", np.floating: "float", np.str_: "utf8" } for i in data: try: if i is None or i is pd.NA or i is pd.NaT or i is np.nan or pd.isna(i): continue except (ValueError, TypeError): pass tt = type(i) if tt in encs: tt = encs[tt] if t is None: t = tt elif t != tt: raise ValueError("Can't infer object conversion type: %s" % data) s += 1 else: raise ValueError("Can't infer object conversion type: %s" % data) if s > 10: break return t def time_shift(indata, outdata, factor=1000): outdata.view("int64")[:] = np.where( indata.view('int64') == nat, nat, indata.view('int64') // factor ) def encode_plain(data, se): """PLAIN encoding; returns byte representation""" out = convert(data, se) if se.type == parquet_thrift.Type.BYTE_ARRAY: return pack_byte_array(list(out)) else: return out.tobytes() def encode_dict(data, _): """ The data part of dictionary encoding is always int8/16, with RLE/bitpack """ width = data.values.dtype.itemsize * 8 buf = np.empty(10, dtype=np.uint8) o = NumpyIO(buf) o.write_byte(width) bit_packed_count = (len(data) + 7) // 8 cencoding.encode_unsigned_varint(bit_packed_count << 1 | 1, o) # write run header # TODO: `bytes`, `tobytes` makes copy, and adding bytes also makes copy return bytes(o.so_far()) + data.values.tobytes() encode = { 'PLAIN': encode_plain, 'RLE_DICTIONARY': encode_dict, } def make_definitions(data, no_nulls, datapage_version=1): """For data that can contain NULLs, produce definition levels binary data: either bitpacked bools, or (if number of nulls == 0), single RLE block.""" buf = np.empty(10, dtype=np.uint8) temp = NumpyIO(buf) if no_nulls: # no nulls at all l = len(data) cencoding.encode_unsigned_varint(l << 1, temp) temp.write_byte(1) if datapage_version == 1: # TODO: adding bytes causes copy block = struct.pack('>> fastparquet.write('myfile.parquet', df) # doctest: +SKIP """ custom_metadata = custom_metadata or {} if getattr(data, "attrs", None): custom_metadata["PANDAS_ATTRS"] = json.dumps(data.attrs) if file_scheme not in ('simple', 'hive', 'drill'): raise ValueError( 'File scheme should be simple|hive|drill, not ' f'{file_scheme}.') fs, filename, open_with, mkdirs = get_fs(filename, open_with, mkdirs) if append == 'overwrite': overwrite(dirpath=filename, data=data, row_group_offsets=row_group_offsets, compression=compression, open_with=open_with, mkdirs=mkdirs, remove_with=None, stats=stats) return if isinstance(partition_on, str): partition_on = [partition_on] if append: pf = ParquetFile(filename, open_with=open_with) if pf._get_index(): # Format dataframe (manage row index). data = reset_row_idx(data) if file_scheme == 'simple': # Case 'simple' if pf.file_scheme not in ['simple', 'empty']: raise ValueError( 'File scheme requested is simple, but ' f'existing file scheme is {pf.file_scheme}.') else: # Case 'hive', 'drill' if pf.file_scheme not in ['hive', 'empty', 'flat']: raise ValueError(f'Requested file scheme is {file_scheme}, but' ' existing file scheme is not.') if tuple(partition_on) != tuple(pf.cats): raise ValueError('When appending, partitioning columns must ' 'match existing data') pf.write_row_groups(data, row_group_offsets, sort_key=None, sort_pnames=False, compression=compression, write_fmd=True, open_with=open_with, mkdirs=mkdirs, stats=stats) else: # Case 'append=False'. # Define 'index_cols' to be recorded in metadata. cols_dtype = data.columns.dtype if (write_index or write_index is None and not isinstance(data.index, pd.RangeIndex)): # Keep name(s) of index to metadata. cols = set(data) data = reset_row_idx(data) index_cols = [c for c in data if c not in cols] elif write_index is None and isinstance(data.index, pd.RangeIndex): # write_index=None, range to metadata index_cols = data.index else: # write_index=False index_cols = [] # Initialize common metadata. if str(has_nulls) == 'infer': has_nulls = None check_column_names(data.columns, partition_on, fixed_text, object_encoding, has_nulls) ignore = partition_on if file_scheme != 'simple' else [] fmd = make_metadata(data, has_nulls=has_nulls, ignore_columns=ignore, fixed_text=fixed_text, object_encoding=object_encoding, times=times, index_cols=index_cols, partition_cols=partition_on, cols_dtype=cols_dtype) if custom_metadata: kvm = fmd.key_value_metadata or [] kvm.extend( [ parquet_thrift.KeyValue(key=key, value=value) for key, value in custom_metadata.items() ] ) fmd.key_value_metadata = kvm if file_scheme == 'simple': # Case 'simple' write_simple(filename, data, fmd, row_group_offsets=row_group_offsets, compression=compression, open_with=open_with, has_nulls=None, append=False, stats=stats) else: # Case 'hive', 'drill' write_multi(filename, data, fmd, row_group_offsets=row_group_offsets, compression=compression, file_scheme=file_scheme, write_fmd=True, open_with=open_with, mkdirs=mkdirs, partition_on=partition_on, append=False, stats=stats) def find_max_part(row_groups): """ Find the highest integer matching "**part.*.parquet" in referenced paths. """ pids = part_ids(row_groups) if pids: return max(pids) + 1 else: return 0 def partition_on_columns(data, columns, root_path, partname, fmd, compression, open_with, mkdirs, with_field=True, stats=True): """ Split each row-group by the given columns Each combination of column values (determined by pandas groupby) will be written in structured directories. """ # Pandas groupby has by default 'sort=True' meaning groups are sorted # between them on key. gb = data.groupby(columns if len(columns) > 1 else columns[0], observed=False) remaining = list(data) for column in columns: remaining.remove(column) if not remaining: raise ValueError("Cannot include all columns in partition_on") rgs = [] for key, group in sorted(gb): if group.empty: continue df = group[remaining] if not isinstance(key, tuple): key = (key,) if with_field: path = join_path(*( "%s=%s" % (name, path_string(val)) for name, val in zip(columns, key) )) else: path = join_path(*("%s" % val for val in key)) relname = join_path(path, partname) mkdirs(join_path(root_path, path)) fullname = join_path(root_path, path, partname) with open_with(fullname, 'wb') as f2: rg = make_part_file(f2, df, fmd.schema, compression=compression, fmd=fmd, stats=stats) if rg is not None: for chunk in rg.columns: chunk.file_path = relname rgs.append(rg) return rgs def write_common_metadata(fn, fmd, open_with=default_open, no_row_groups=True): """ For hive-style parquet, write schema in special shared file Parameters ---------- fn: str Filename to write to fmd: thrift FileMetaData Information to write open_with: func To use to create writable file as f(path, mode) no_row_groups: bool (True) Strip out row groups from metadata before writing - used for "common metadata" files, containing only the schema. """ consolidate_categories(fmd) with open_with(fn, 'wb') as f: f.write(MARKER) if no_row_groups: fmd = copy(fmd) fmd.row_groups = [] foot_size = write_thrift(f, fmd) else: foot_size = write_thrift(f, fmd) f.write(struct.pack(b" cat['metadata'][ 'num_categories']: cat['metadata']['num_categories'] = int(ncats[0]) key_value[2] = json.dumps(meta, sort_keys=True).encode() def merge(file_list, verify_schema=True, open_with=default_open, root=False): """ Create a logical data-set out of multiple parquet files. The files referenced in file_list must either be in the same directory, or at the same level within a structured directory, where the directories give partitioning information. The schemas of the files should also be consistent. Parameters ---------- file_list: list of paths or ParquetFile instances verify_schema: bool (True) If True, will first check that all the schemas in the input files are identical. open_with: func Used for opening a file for writing as f(path, mode). If input list is ParquetFile instances, will be inferred from the first one of these. root: str If passing a list of files, the top directory of the data-set may be ambiguous for partitioning where the upmost field has only one value. Use this to specify the data'set root directory, if required. Returns ------- ParquetFile instance corresponding to the merged data. """ out = ParquetFile(file_list, verify_schema, open_with, root) out._write_common_metadata(open_with) return out def overwrite(dirpath, data, row_group_offsets=None, sort_pnames:bool=True, compression=None, open_with=default_open, mkdirs=None, remove_with=None, stats=True): """Merge new data to existing parquet dataset. This function requires existing data on disk, written with 'hive' format. This function is a work-in-progress. Several update modes can be envisaged and in the mid term, this function will provide a skeleton for achieving update of an existing dataset with new data. With current version, the only *update mode* supported is ``overwrite_partitioned``. With this mode, row-groups on disk that have partition values overlapping with those of new data are removed first before new data is added. Parameters ---------- dirpath : str Directory path containing a parquet dataset, written with hive format, and with defined partitions. data : pandas dataframe The table to write. row_group_offsets : int or list of int, optional If int, row-groups will be approximately this many rows, rounded down to make row groups about the same size; If a list, the explicit index values to start new row groups; If `None`, set to 50_000_000. In case of partitioning the data, final row-groups size can be reduced significantly further by the partitioning, occuring as a subsequent step. sort_pnames: bool, default True Align name of part files with position of the 1st row group they contain. compression : str or dict, optional Compression to apply to each column, e.g. ``GZIP`` or ``SNAPPY`` or a ``dict`` like ``{"col1": "SNAPPY", "col2": None}`` to specify per column compression types. By default, do not compress. Please, review full description of this parameter in `write` docstring. open_with : function, optional When called with a f(path, mode), returns an open file-like object. mkdirs : function, optional When called with a path/URL, creates any necessary dictionaries to make that location writable, e.g., ``os.makedirs``. remove_with : function, optional When called with f(path), removes file or directory specified by `path` (and any contained files). stats: True|False|list of str Whether to calculate and write summary statistics. If True (default), do it for every column; If False, never do; And if a list of ``str``, do it only for those specified columns. """ pf = ParquetFile(dirpath, open_with=open_with) if (pf.file_scheme == 'simple' or (pf.file_scheme == 'empty' and pf.fn[-9:] != '_metadata')): raise ValueError('Not possible to overwrite with simple file ' 'scheme.') defined_partitions = list(pf.cats) if not defined_partitions: raise ValueError('No partitioning column has been set in existing ' 'dataset. Overwrite of partitions is not possible.') # 1st step (from existing data). # Define 'sort_key' function to be used to sort all row groups once those # of new data will have been added. # 'partitions_starts' is a `dict` that keeps index of 1st row group for # each partition in existing data. n_rgs = len(pf.row_groups) max_idx = n_rgs-1 partitions_starts = {partitions(rg): (max_idx-i) for i, rg in enumerate(reversed(pf.row_groups))} def sort_key(row_group) -> int: """Return 1st row-group index with same partition. If no partition matching, returns an index larger than the 1st row-group indexes of any existing partitions. """ # Taking n_rgs (=len(pf.row_groups)) as index for row-groups without # matching partition among existing ones is overkill but works. rg_partition = partitions(row_group) return (partitions_starts[rg_partition] if (rg_partition in partitions_starts) else n_rgs) # 2nd step (from new and existing data). # Identify row groups from existing data with same partition values as # those in new data. partition_values_in_new = pd.unique(data.loc[:,defined_partitions] .astype(str).agg('/'.join, axis=1)) rgs_to_remove = filter(lambda rg : (partitions(rg, True) in partition_values_in_new), pf.row_groups) # 3rd step (on new data). # Format new data so that it can be written to disk. if pf._get_index(): # Reset index of pandas dataframe. data = reset_row_idx(data) # 4th step: write new data, remove previously existing row groups, # sort row groups and write updated metadata. pf.write_row_groups(data, row_group_offsets=row_group_offsets, sort_key=sort_key, compression=compression, write_fmd=False, open_with=open_with, mkdirs=mkdirs, stats=stats) pf.remove_row_groups(rgs_to_remove, sort_pnames=sort_pnames, write_fmd=True, open_with=open_with, remove_with=remove_with) def write_thrift(f, obj): # TODO inline this if obj.thrift_name == "FileMetaData": for kv in obj.key_value_metadata: if not isinstance(kv.key, (bytes, str)): raise TypeError(f"KeyValue key expected `str` or `bytes`, got: {kv.key!r}") if not isinstance(kv.value, (bytes, str)): raise TypeError(f"KeyValue value expected `str` or `bytes`, got: {kv.value!r}") return f.write(obj.to_bytes()) def update_file_custom_metadata(path: str, custom_metadata: dict, is_metadata_file: bool = None): """Update metadata in file without rewriting data portion if a data file. This function updates only the user key-values metadata, not the whole metadata of a parquet file. Update strategy depends if key found in new custom metadata is also found in already existing custom metadata within thrift object, as well as its value. - If not found in existing, it is added. - If found in existing, it is updated. - If its value is `None`, it is not added, and if found in existing, it is removed from existing. Parameters ---------- path : str Local path to file. custom_metadata : dict Key-value metadata to update in thrift object. The values must be strings or binary. To pass a dictionary, serialize it as json string then encode it in binary. is_metadata_file : bool, default None Define if target file is a pure metadata file, or is a parquet data file. If `None`, is set depending file name. - if ending with '_metadata', it assumes file is a metadata file. - otherwise, it assumes it is a parquet data file. Notes ----- This method does not work for remote files. """ if is_metadata_file is None: if path[-9:] == '_metadata': is_metadata_file = True else: is_metadata_file = False with open(path, "rb+") as f: if is_metadata_file: # For pure metadata file, metadata starts just four bytes in. loc = 4 else: loc0 = f.seek(-8, 2) size = int.from_bytes(f.read(4), "little") loc = loc0 - size f.seek(loc) data = f.read() fmd = from_buffer(data, "FileMetaData") update_custom_metadata(fmd, custom_metadata) f.seek(loc) foot_size = write_thrift(f, fmd) f.write(struct.pack(b" 2**31: raise OverflowError return x