Batch

This module automatically splits a DataFrame into smaller DataFrames (by clusters of columns) and performs model training and text generation on each sub-DataFrame independently.

The resulting sub-DataFrames can then be concatenated back into one final synthetic dataset.
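As a rough illustration of the split / re-assemble idea, the sketch below uses plain lists of dicts in place of DataFrames. The helpers `split_by_columns` and `reassemble` are hypothetical and not part of gretel_synthetics. They only show that each sub-table keeps every row but only one cluster of columns, and that joining the sub-tables row by row restores full records.

```python
from typing import Dict, List

Record = Dict[str, str]

def split_by_columns(rows: List[Record], clusters: List[List[str]]) -> List[List[Record]]:
    """Hypothetical: one sub-table per column cluster, each keeping every row."""
    return [[{col: row[col] for col in cluster} for row in rows] for cluster in clusters]

def reassemble(sub_tables: List[List[Record]]) -> List[Record]:
    """Hypothetical: join sub-tables back together row by row (by position)."""
    merged: List[Record] = []
    for parts in zip(*sub_tables):
        row: Record = {}
        for part in parts:
            row.update(part)
        merged.append(row)
    return merged

rows = [
    {"name": "alice", "age": "34", "city": "Austin"},
    {"name": "bob", "age": "27", "city": "Boston"},
]
clusters = [["name", "age"], ["city"]]
subs = split_by_columns(rows, clusters)
print(subs[1])                   # [{'city': 'Austin'}, {'city': 'Boston'}]
print(reassemble(subs) == rows)  # True
```

In the real workflow each sub-table comes from an independently trained model, so the sketch's purely positional join is only meant to show the shape of the operation.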

For example usage, please see our Jupyter Notebook.

class gretel_synthetics.batch.Batch(checkpoint_dir: str, input_data_path: str, headers: List[str], config: TensorFlowConfig, gen_data_count: int = 0)

A representation of a synthetic data workflow. It should not be used directly. This object is created automatically by the primary batch handler, such as DataFrameBatch. This class holds all of the necessary information for training, data generation and DataFrame re-assembly.

add_valid_data(data: GenText)

Take a GenText object and add its generated line to the generated data stream.

get_validator()

If a custom validator is set, we return that. Otherwise, we return the built-in validator, which simply checks if a generated line has the right number of values based on the number of headers for this batch.

This at least ensures that the resulting DataFrame will have the right shape.
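A minimal sketch of the kind of shape check described above, assuming comma-delimited lines and a validator that returns a bool. The name `make_shape_validator` is hypothetical; the built-in validator may parse quoted values differently or signal failure by raising instead of returning False.

```python
from typing import Callable, List

def make_shape_validator(headers: List[str], delimiter: str = ",") -> Callable[[str], bool]:
    """Hypothetical: accept a line only if it has one value per header."""
    expected = len(headers)

    def validate(line: str) -> bool:
        # Naive split: assumes the delimiter never appears inside a value.
        return len(line.split(delimiter)) == expected

    return validate

check = make_shape_validator(["name", "age", "city"])
print(check("alice,34,Austin"))  # True
print(check("alice,34"))         # False
```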

load_validator_from_file()

Load a saved validation object if it exists

reset_gen_data()

Reset all objects that accumulate or track synthetic data generation

set_validator(fn: Callable, save=True)

Assign a validation callable to this batch. If save is True (the default), the validator is also pickled and saved so it can be loaded later.

property synthetic_df: DataFrame

Get a DataFrame constructed from the generated lines

class gretel_synthetics.batch.DataFrameBatch(*, df: DataFrame | None = None, batch_size: int = 15, batch_headers: List[List[str]] | None = None, config: dict | BaseConfig | None = None, tokenizer: BaseTokenizerTrainer | None = None, mode: str = 'write', checkpoint_dir: str | None = None, validate_model: bool = True)

Create a multi-batch trainer / generator. When created, the directory structure to store models and training data will automatically be created under the “checkpoint_dir” location provided in the config template. There will be one directory per batch, named “batch_N”, where N is the batch number starting from 0.

Training and generation can happen per batch, or you can loop over all batches to train and generate for each one.

Example

When creating this object, you must explicitly create the training data from the input DataFrame before training models:

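from gretel_synthetics.batch import DataFrameBatch

# my_df: the source pandas DataFrame; my_config: a config dict or BaseConfig
my_batch = DataFrameBatch(df=my_df, config=my_config)
my_batch.create_training_data()
my_batch.train_all_batches()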
Parameters:
  • df – The input, source DataFrame

  • batch_size – If batch_headers is not provided, the columns of the source DataFrame are automatically broken up into batches of at most batch_size columns.

  • batch_headers – A list of lists of strings can be provided which will control the number of batches. The number of inner lists is the number of batches, and each inner list represents the columns that belong to that batch

  • config – A template training config to use; it will be used as kwargs for each Batch’s synthetic configuration. This may also be an instance of a BaseConfig subclass. If a config object is used, you can set its input_data_path param to the constant PATH_HOLDER, since each batch sets its own training data path.

  • tokenizer – An optional instance of a BaseTokenizerTrainer subclass. If not provided, the default tokenizer for the underlying ML engine will be used.

Note

When providing a config, the source of training data is not necessary, only the checkpoint_dir is needed. Each batch will control its input training data path after it creates the training dataset.
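The sketch below shows one plausible way the batch_size and batch_headers parameters could map to numbered batches, with keys starting at 0 to match the “batch_N” directories. The function `plan_batches` is hypothetical, and the library's real chunking may differ (for example, it might balance batch sizes rather than filling them in order).

```python
from typing import Dict, List, Optional

def plan_batches(columns: List[str], batch_size: int = 15,
                 batch_headers: Optional[List[List[str]]] = None) -> Dict[int, List[str]]:
    """Hypothetical: map batch numbers (from 0) to column clusters."""
    if batch_headers is not None:
        # Explicit clusters win: one batch per inner list.
        clusters = [list(cluster) for cluster in batch_headers]
    else:
        # Otherwise, chunk the columns in order into groups of at most batch_size.
        clusters = [columns[i:i + batch_size] for i in range(0, len(columns), batch_size)]
    return dict(enumerate(clusters))

cols = [f"c{i}" for i in range(7)]
print(plan_batches(cols, batch_size=3))
# {0: ['c0', 'c1', 'c2'], 1: ['c3', 'c4', 'c5'], 2: ['c6']}
print(plan_batches(cols, batch_headers=[["c0", "c6"], ["c1", "c2"]]))
# {0: ['c0', 'c6'], 1: ['c1', 'c2']}
```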

batch_size: int

The max number of columns allowed for a single DF batch

batch_to_df(batch_idx: int) → DataFrame

Extract a synthetic data DataFrame from a single batch.

Parameters:

batch_idx – The batch number

Returns:

A DataFrame with synthetic data

batches: Dict[int, Batch]

A mapping of batch numbers to Batch objects. The batch number (key) increments from 0 to N-1, where N is the number of batches being used.

batches_to_df() → DataFrame

Convert all batches to a single synthetic data DataFrame.

Returns:

A single DataFrame that is the concatenation of all the batch DataFrames.

config: dict | BaseConfig

The template config that will be used for all batches. If a dict is provided we default to a TensorFlowConfig.

create_training_data()

Split the original DataFrame into N smaller DataFrames. Each smaller DataFrame will have the same number of rows, but a subset of the columns from the original DataFrame.

This method iterates over each Batch object and assigns a smaller training DataFrame to the training_df attribute of the object.

Finally, a training CSV is written to disk in the specific batch directory
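A minimal sketch of writing one training CSV per batch directory, using only the standard library. The function `write_training_data` is hypothetical, and the file name `train.csv` and the presence of a header row are assumptions rather than the library's actual on-disk format.

```python
import csv
import tempfile
from pathlib import Path
from typing import Dict, List

def write_training_data(checkpoint_dir: Path, rows: List[Dict[str, str]],
                        batches: List[List[str]]) -> List[Path]:
    """Hypothetical: write one CSV per batch under checkpoint_dir/batch_N."""
    paths = []
    for idx, headers in enumerate(batches):
        batch_dir = checkpoint_dir / f"batch_{idx}"
        batch_dir.mkdir(parents=True, exist_ok=True)
        path = batch_dir / "train.csv"  # file name is an assumption
        with path.open("w", newline="") as fh:
            # Same rows as the source, but only this batch's columns.
            writer = csv.DictWriter(fh, fieldnames=headers, extrasaction="ignore")
            writer.writeheader()
            writer.writerows(rows)
        paths.append(path)
    return paths

rows = [{"name": "alice", "age": "34", "city": "Austin"}]
with tempfile.TemporaryDirectory() as tmp:
    written = write_training_data(Path(tmp), rows, [["name", "age"], ["city"]])
    names = [p.relative_to(tmp).as_posix() for p in written]
    city_lines = written[1].read_text().splitlines()

print(names)       # ['batch_0/train.csv', 'batch_1/train.csv']
print(city_lines)  # ['city', 'Austin']
```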

generate_all_batch_lines(max_invalid=1000, raise_on_failed_batch: bool = False, num_lines: int | None = None, seed_fields: dict | List[dict] | None = None, parallelism: int = 0) → Dict[int, GenerationSummary]

Generate synthetic lines for all batches. Lines for each batch are added to the individual Batch objects. Once generation is done, you may re-assemble the dataset into a DataFrame.

Example:

my_batch.generate_all_batch_lines()
# Wait for all generation to complete
synthetic_df = my_batch.batches_to_df()
Parameters:
  • max_invalid – The maximum number of invalid lines allowed per batch. If this number is exceeded for any batch, generation will stop.

  • raise_on_failed_batch – If True, an exception will be raised if any single batch fails to generate the requested number of lines. If False, the failed batch will be reported with is_valid=False in the result dictionary returned by this method.

  • num_lines

    The number of lines to create from each batch. If None then the value from the config template will be used.

    Note

    Overridden if seed_fields is a list; in that case it is set to the length of the list.

  • seed_fields

    A dictionary that maps field/column names to initial seed values for those columns. This seed will only apply to the first batch that gets trained and generated. Additionally, the fields provided in the mapping MUST exist at the front of the first batch.

    Note

    This param may also be a list of dicts. If this is the case, then num_lines will automatically be set to the list length downstream, and a 1:1 ratio will be used for generating valid lines for each prefix.

  • parallelism – The number of concurrent workers to use. 1 disables parallelization, while a non-positive value means “number of CPUs + x” (i.e., use 0, the default, for as many workers as there are CPUs). A floating-point value is interpreted as a fraction of the available CPUs, rounded down.
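The num_lines and seed_fields rules above interact, so here is a small sketch of them. Both helpers are hypothetical. Requiring the seeded fields to match the first batch's leading columns in the same order is a stricter reading of “MUST exist at the front of the first batch” and is an assumption.

```python
from typing import Dict, List, Optional, Union

Seed = Dict[str, str]

def effective_num_lines(num_lines: Optional[int], default: int,
                        seed_fields: Optional[Union[Seed, List[Seed]]] = None) -> int:
    """Hypothetical: a seed list overrides num_lines; None falls back to the config."""
    if isinstance(seed_fields, list):
        return len(seed_fields)  # one valid line per seed prefix (1:1)
    return default if num_lines is None else num_lines

def seeds_lead_first_batch(seed: Seed, first_batch_headers: List[str]) -> bool:
    """Hypothetical: check the seeded fields are the leading columns of the first batch."""
    keys = list(seed)
    return keys == first_batch_headers[:len(keys)]

headers = ["state", "city", "age"]
print(effective_num_lines(None, default=500))                              # 500
print(effective_num_lines(100, 500, [{"state": "TX"}, {"state": "CA"}]))   # 2
print(seeds_lead_first_batch({"state": "TX", "city": "Austin"}, headers))  # True
print(seeds_lead_first_batch({"city": "Austin"}, headers))                 # False
```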

Returns:

A dictionary mapping each batch number to a GenerationSummary that reports the number of valid and invalid lines, plus a bool (is_valid) that shows whether the batch was able to generate the full number of requested lines:

{
    0: GenerationSummary(valid_lines=1000, invalid_lines=10, is_valid=True),
    1: GenerationSummary(valid_lines=500, invalid_lines=5, is_valid=True)
}
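With raise_on_failed_batch=False, failures surface in the returned dictionary rather than as an exception, so a caller may want to scan it. The sketch uses a local `Summary` dataclass that mirrors the documented GenerationSummary fields so it runs without the library; the numbers are illustrative, not real output.

```python
from dataclasses import dataclass
from typing import Dict, List

@dataclass
class Summary:
    """Local stand-in mirroring the documented GenerationSummary fields."""
    valid_lines: int = 0
    invalid_lines: int = 0
    is_valid: bool = False

def failed_batches(results: Dict[int, Summary]) -> List[int]:
    """Return the batch numbers that did not produce all requested lines."""
    return sorted(idx for idx, summary in results.items() if not summary.is_valid)

# Illustrative values only.
results = {
    0: Summary(valid_lines=1000, invalid_lines=10, is_valid=True),
    1: Summary(valid_lines=320, invalid_lines=1001, is_valid=False),
}
print(failed_batches(results))  # [1]
```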

generate_batch_lines(batch_idx: int, max_invalid=1000, raise_on_exceed_invalid: bool = False, num_lines: int | None = None, seed_fields: dict | List[dict] | None = None, parallelism: int = 0) → GenerationSummary

Generate lines for a single batch. Lines generated are added to the underlying Batch object for each batch. The lines can be accessed after generation and re-assembled into a DataFrame.

Parameters:
  • batch_idx – The batch number

  • max_invalid – The max number of invalid lines that can be generated, if this is exceeded, generation will stop

  • raise_on_exceed_invalid – If True and the number of invalid lines generated exceeds the max_invalid amount, the error thrown by the generation module will be re-raised, interrupting the running process. Otherwise, the caught exception is not raised, and the returned summary indicates that the batch failed to generate all lines.

  • num_lines – The number of lines to generate, if None, then we use the number from the batch’s config

  • seed_fields

    A dictionary that maps field/column names to initial seed values for those columns. This seed will only apply to the first batch that gets trained and generated. Additionally, the fields provided in the mapping MUST exist at the front of the first batch.

    Note

    This param may also be a list of dicts. If this is the case, then num_lines will automatically be set to the list length downstream, and a 1:1 ratio will be used for generating valid lines for each prefix.

  • parallelism – The number of concurrent workers to use. 1 disables parallelization, while a non-positive value means “number of CPUs + x” (i.e., use 0, the default, for as many workers as there are CPUs). A floating-point value is interpreted as a fraction of the available CPUs, rounded down.
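One way to read the documented parallelism rules is sketched below. The function `resolve_workers` is hypothetical, and clamping the result to at least one worker is an assumption; the library may handle those edge cases differently.

```python
import math
import os
from typing import Optional, Union

def resolve_workers(parallelism: Union[int, float], cpu_count: Optional[int] = None) -> int:
    """Hypothetical reading of the documented `parallelism` semantics."""
    cpus = cpu_count if cpu_count is not None else (os.cpu_count() or 1)
    if isinstance(parallelism, float):
        # Fraction of available CPUs, rounded down (clamped to at least 1 here).
        return max(1, math.floor(cpus * parallelism))
    if parallelism <= 0:
        # Non-positive x means "number of CPUs + x" (clamped to at least 1 here).
        return max(1, cpus + parallelism)
    return parallelism

print(resolve_workers(0, cpu_count=8))    # 8
print(resolve_workers(-2, cpu_count=8))   # 6
print(resolve_workers(1, cpu_count=8))    # 1 (no parallelism)
print(resolve_workers(0.5, cpu_count=8))  # 4
```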

master_header_list: List[str]

During training, this is the original column order. When reading from disk, we concatenate all headers from all batches together. This list is not guaranteed to preserve the original header order.

original_headers: List[str]

Stores the original header list / order from the original training data that was used. This is written out to the model directory during training and loaded back in when using read-only mode.
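Because the concatenated header list is not guaranteed to keep the original order, a caller may want to restore it using the stored original headers. The helper `restore_column_order` below is hypothetical and works on a single record dict; with a pandas DataFrame the same idea is column selection, df[original_headers].

```python
from typing import Dict, List

def restore_column_order(record: Dict[str, str], original_headers: List[str]) -> Dict[str, str]:
    """Hypothetical: reorder a re-assembled record to match original_headers."""
    return {col: record[col] for col in original_headers}

original = ["name", "age", "city"]
# e.g. headers concatenated from batches in a different order
reassembled = {"city": "Austin", "name": "alice", "age": "34"}
print(list(restore_column_order(reassembled, original)))  # ['name', 'age', 'city']
```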

set_batch_validator(batch_idx: int, validator: Callable)

Set a validator for a specific batch. If a validator is configured for a batch, each generated record from that batch will be sent to the validator.

Parameters:
  • batch_idx – The batch number.

  • validator – A callable that should take exactly one argument, which will be the raw line generated from the generate_text function.
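A sketch of a custom per-batch validator that receives the raw generated line. It assumes comma-delimited lines and signals an invalid line by raising ValueError; that failure convention is an assumption, so check how your version treats the validator's result. The names `make_age_validator`, `passes`, and the `age` column are hypothetical. A validator built this way could then be registered with set_batch_validator(0, check_batch).

```python
import csv
from typing import Callable, List

def make_age_validator(headers: List[str], column: str = "age") -> Callable[[str], None]:
    """Hypothetical per-batch validator: raise unless `column` is a whole number."""
    pos = headers.index(column)

    def validator(raw_line: str) -> None:
        values = next(csv.reader([raw_line]))  # parse one raw CSV line
        if len(values) != len(headers):
            raise ValueError("wrong number of values")
        if not values[pos].isdigit():
            raise ValueError(f"{column} must be a whole number")

    return validator

def passes(validator: Callable[[str], None], raw_line: str) -> bool:
    """True if the validator accepts the line without raising."""
    try:
        validator(raw_line)
    except ValueError:
        return False
    return True

check_batch = make_age_validator(["name", "age"])
print(passes(check_batch, "alice,34"))  # True
print(passes(check_batch, "bob,old"))   # False
print(passes(check_batch, "carol"))     # False
```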

train_all_batches()

Train a model for each batch.

train_batch(batch_idx: int)

Train a model for a single batch. All model information will be written into that batch’s directory.

Parameters:

batch_idx – The index of the batch, from the batches dictionary

class gretel_synthetics.batch.GenerationProgress(current_valid_count: int = 0, current_invalid_count: int = 0, new_valid_count: int = 0, new_invalid_count: int = 0, completion_percent: float = 0.0, timestamp: float = <factory>)

This class should not have to be used directly.

It is used to communicate the current progress of record generation.

When a callback function is passed to the RecordFactory.generate_all() method, each time the callback is called, an instance of this class will be passed as the single argument:

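from gretel_synthetics.batch import GenerationProgress

def my_callback(data: GenerationProgress):
    # Called periodically with a snapshot of generation progress
    print(f"{data.completion_percent}% complete")

# factory is a RecordFactory, e.g. from batcher.create_record_factory(num_lines=50)
result = factory.generate_all(output="df", callback=my_callback)
df = result.records  # the generated records as a DataFrame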

This class is used to periodically communicate generation progress to the user through a callback that can be passed to the RecordFactory.generate_all() method.

completion_percent: float = 0.0

The percentage of valid lines/records that have been generated.

current_invalid_count: int = 0

The number of invalid lines/records that were generated so far.

current_valid_count: int = 0

The number of valid lines/records that were generated so far.

new_invalid_count: int = 0

The number of new invalid lines/records that were generated since the last progress callback.

new_valid_count: int = 0

The number of new valid lines/records that were generated since the last progress callback.

timestamp: float

The timestamp at which the information in this object was captured.
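To make the difference between the current_* and new_* counters concrete, the sketch below builds progress snapshots with deltas since the previous snapshot. `Progress` is a local stand-in mirroring the documented GenerationProgress fields, and `ProgressTracker` is hypothetical; expressing completion_percent on a 0 to 100 scale is an assumption.

```python
import time
from dataclasses import dataclass, field

@dataclass
class Progress:
    """Local stand-in mirroring the documented GenerationProgress fields."""
    current_valid_count: int = 0
    current_invalid_count: int = 0
    new_valid_count: int = 0
    new_invalid_count: int = 0
    completion_percent: float = 0.0
    timestamp: float = field(default_factory=time.time)

class ProgressTracker:
    """Hypothetical helper that builds snapshots with deltas since the last one."""

    def __init__(self, target: int) -> None:
        self.target = target
        self._last_valid = 0
        self._last_invalid = 0

    def snapshot(self, valid: int, invalid: int) -> Progress:
        progress = Progress(
            current_valid_count=valid,
            current_invalid_count=invalid,
            new_valid_count=valid - self._last_valid,
            new_invalid_count=invalid - self._last_invalid,
            completion_percent=100.0 * valid / self.target if self.target else 0.0,
        )
        self._last_valid, self._last_invalid = valid, invalid
        return progress

tracker = ProgressTracker(target=200)
tracker.snapshot(valid=50, invalid=3)
p = tracker.snapshot(valid=120, invalid=5)
print(p.new_valid_count, p.new_invalid_count, p.completion_percent)  # 70 2 60.0
```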

class gretel_synthetics.batch.GenerationResult(records: pandas.core.frame.DataFrame | List[dict], exception: Exception | None = None)

class gretel_synthetics.batch.GenerationSummary(valid_lines: int = 0, invalid_lines: int = 0, is_valid: bool = False)

A class to capture the summary data after synthetic data is generated.

class gretel_synthetics.batch.RecordFactory(*, num_lines: int, batches: dict, header_list: list, delimiter: str, seed_fields: dict | list | None = None, max_invalid=1000, validator: Callable | None = None, parallelism: int = 4, invalid_cache_size: int = 100)

A stateful factory that can be used to generate and validate entire records, regardless of the number of underlying header clusters that were used to build multiple sub-models.

Instances of this class should be created by calling the appropriate method of the DataFrameBatch instance; you should not have to use this class directly. You should be able to create an instance like so:

factory = batcher.create_record_factory(num_lines=50)

The class is initialized with the default capacity and limits specified by the num_lines and max_invalid attributes. At any time, you can inspect the state of the instance by doing:

factory.summary

The factory instance can be used one of two ways: buffered or unbuffered.

For unbuffered mode, the entire instance can be used as an iterator to create synthetic records. Each record will be a dictionary.

Note

All values in the generated dictionaries will be strings.

The valid_count and invalid_count counters will update as records are generated.

When creating the record factory, you may also provide an entire record validator:

def validator(rec: dict):
    ...

factory = batcher.create_record_factory(num_lines=50, validator=validator)

Each generated record dict will be passed to the validator. This validator may either return False or raise an exception to mark a record as invalid.
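The “return False or raise” rule can be sketched as below. The wrapper `record_is_valid` is hypothetical; treating any other return value (including None) as valid is our reading of the rule, not a confirmed detail of the library. Since all generated values are strings, the example validator converts before comparing.

```python
from typing import Callable, Dict

Record = Dict[str, str]

def record_is_valid(validator: Callable[[Record], object], rec: Record) -> bool:
    """Hypothetical: a False return or any exception marks the record invalid."""
    try:
        return validator(rec) is not False
    except Exception:
        return False

def validator(rec: Record):
    # All generated values are strings, so convert before checking.
    if int(rec["age"]) < 0:
        return False

print(record_is_valid(validator, {"age": "34"}))   # True
print(record_is_valid(validator, {"age": "-1"}))   # False (returned False)
print(record_is_valid(validator, {"age": "old"}))  # False (int() raised)
```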

At any point, you may reset the state of the factory by calling:

factory.reset()

This will reset all counters and allow you to keep generating records.
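The toy class below imitates the unbuffered usage pattern described above: iterate to get records, watch the counters, then call reset() to keep going. `ToyRecordFactory` is entirely hypothetical, uses a random number generator in place of trained models, and for brevity only treats an explicit False from the validator as invalid. Raising once max_invalid is exceeded is also an assumption about how a limit might be enforced.

```python
import random
from typing import Callable, Dict, Iterator, Optional

Record = Dict[str, str]

class ToyRecordFactory:
    """Hypothetical stand-in showing iteration, counters and reset()."""

    def __init__(self, num_lines: int, max_invalid: int = 1000,
                 validator: Optional[Callable[[Record], bool]] = None,
                 seed: int = 0) -> None:
        self.num_lines = num_lines
        self.max_invalid = max_invalid
        self.validator = validator
        self._rng = random.Random(seed)
        self.reset()

    def reset(self) -> None:
        """Zero the counters so generation can start again."""
        self.valid_count = 0
        self.invalid_count = 0

    def _make_record(self) -> Record:
        # Stand-in for the per-batch models: every value is a string.
        return {"age": str(self._rng.randint(-5, 90))}

    def __iter__(self) -> Iterator[Record]:
        while self.valid_count < self.num_lines:
            rec = self._make_record()
            if self.validator is not None and self.validator(rec) is False:
                self.invalid_count += 1
                if self.invalid_count > self.max_invalid:
                    raise RuntimeError("max_invalid exceeded")
                continue
            self.valid_count += 1
            yield rec

factory = ToyRecordFactory(num_lines=5, validator=lambda r: int(r["age"]) >= 0)
records = list(factory)
print(len(records), factory.valid_count)  # 5 5
factory.reset()
print(factory.valid_count, factory.invalid_count)  # 0 0
more = list(factory)  # generation continues after reset
```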

Finally, you can generate records in buffered mode, where generated records will be buffered in memory and returned as one collection. By default, a list of dicts will be returned:

factory.generate_all()

You may request the records to be returned as a DataFrame. The dtypes will be inferred as if you were reading the data from a CSV:

factory.generate_all(output="df")

Note

When using generate_all, the factory state will be reset automatically.

generate_all(output: str | None = None, callback: callable | None = None, callback_interval: int = 30, callback_threading: bool = False) → GenerationResult

Attempt to generate the full number of records that was set when creating the RecordFactory. This method will create a buffer that holds all records and then returns the buffer once generation is complete.

Parameters:
  • output – How the records should be returned. If None (the default), a list of record dicts will be returned. The other supported option is ‘df’, for a DataFrame.

  • callback – An optional callable that will periodically be called with a GenerationProgress instance as the single argument while records are being generated.

  • callback_interval – If using a callback, the minimum number of seconds that should elapse between callbacks.

  • callback_threading – If enabled, a watchdog thread will be used to execute the callback. This ensures that the callback is called regardless of invalid or valid counts. If callback threading is disabled, the callback will only be called after valid records are generated. If the callback raises an exception, a threading event will be set, which will trigger the stopping of generation.

Returns:

A GenerationResult whose records attribute holds the generated records. The type of records depends on the output param: by default it is a list of dicts, and with output=‘df’ it is a DataFrame.
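The callback_interval behaviour (a minimum gap between callbacks) can be sketched with a small throttle. `ThrottledCallback` is hypothetical and roughly models the non-threaded mode, where a check would run after each valid record; whether the first callback fires immediately in the library is an assumption. Injecting the clock keeps the example deterministic.

```python
import time
from typing import Callable, Optional

class ThrottledCallback:
    """Hypothetical: forward progress to `callback` at most once per `interval` seconds."""

    def __init__(self, callback: Callable[[object], None], interval: float = 30,
                 clock: Callable[[], float] = time.monotonic) -> None:
        self.callback = callback
        self.interval = interval
        self.clock = clock
        self._last: Optional[float] = None

    def maybe_call(self, progress: object) -> bool:
        """Invoke the callback if enough time has passed; report whether it ran."""
        now = self.clock()
        if self._last is not None and now - self._last < self.interval:
            return False  # too soon since the previous callback
        self._last = now
        self.callback(progress)
        return True

# A fake clock makes the throttling deterministic.
ticks = iter([0.0, 10.0, 31.0, 40.0])
seen = []
throttle = ThrottledCallback(seen.append, interval=30, clock=lambda: next(ticks))
calls = [throttle.maybe_call(n) for n in range(4)]
print(calls)  # [True, False, True, False]
print(seen)   # [0, 2]
```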

validator: Callable

An optional callable that will receive a fully constructed record for one final validation before returning or yielding a single record. Records that do not pass this validation will also increment the invalid_count.