Batch¶
This module allows automatic splitting of a DataFrame into smaller DataFrames (by clusters of columns), with model training and text generation run on each sub-DF independently.
The generated sub-DFs can then be concatenated back into one final synthetic dataset.
For example usage, please see our Jupyter Notebook.
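The reference below is easier to follow with the whole workflow in view. Here is a minimal sketch, assuming an existing pandas DataFrame my_df; the checkpoint directory and batch size are illustrative:

from gretel_synthetics.batch import DataFrameBatch, PATH_HOLDER
from gretel_synthetics.config import TensorFlowConfig

# The config acts as a template for every batch; per the DataFrameBatch
# docs below, input_data_path may be set to the PATH_HOLDER constant.
config = TensorFlowConfig(
    checkpoint_dir="/tmp/batch-models",
    input_data_path=PATH_HOLDER,
)

batcher = DataFrameBatch(df=my_df, config=config, batch_size=15)
batcher.create_training_data()        # split my_df and write one CSV per batch
batcher.train_all_batches()           # train one model per column cluster
batcher.generate_all_batch_lines()    # generate synthetic lines for every batch
synthetic_df = batcher.batches_to_df()  # concatenate back into one DataFrame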
class gretel_synthetics.batch.Batch(checkpoint_dir: str, input_data_path: str, headers: List[str], config: gretel_synthetics.config.TensorFlowConfig, gen_data_count: int = 0)¶
A representation of a synthetic data workflow. It should not be used directly. This object is created automatically by the primary batch handler, such as DataFrameBatch. This class holds all of the necessary information for training, data generation, and DataFrame re-assembly.
add_valid_data(data: gretel_synthetics.generate.GenText)¶
Take a gen_text object and add the generated line to the generated data stream.
get_validator()¶
If a custom validator is set, we return that. Otherwise, we return the built-in validator, which simply checks if a generated line has the right number of values based on the number of headers for this batch. This at least ensures the resulting DataFrame will be the right shape.
load_validator_from_file()¶
Load a saved validation object if it exists.
reset_gen_data()¶
Reset all objects that accumulate or track synthetic data generation.
set_validator(fn: Callable, save=True)¶
Assign a validation callable to this batch, optionally pickling and saving the validator for later loading.
property synthetic_df¶
Get a DataFrame constructed from the generated lines.
class gretel_synthetics.batch.DataFrameBatch(*, df: pandas.core.frame.DataFrame = None, batch_size: int = 15, batch_headers: List[List[str]] = None, config: Union[dict, gretel_synthetics.config.BaseConfig] = None, tokenizer: gretel_synthetics.tokenizers.BaseTokenizerTrainer = None, mode: str = 'write', checkpoint_dir: str = None)¶
Create a multi-batch trainer / generator. When created, the directory structure to store models and training data will automatically be created. The directory structure will be created under the “checkpoint_dir” location provided in the config template. There will be one directory per batch, where each directory will be called “batch_N”, where N is the batch number, starting from 0. Training and generating can happen per-batch, or we can loop over all batches to do both train / generation functions.
Example
When creating this object, you must explicitly create the training data from the input DataFrame before training models:
my_batch = DataFrameBatch(df=my_df, config=my_config)
my_batch.create_training_data()
my_batch.train_all_batches()
Parameters
df – The input, source DataFrame
batch_size – If batch_headers is not provided, we automatically break up the number of columns in the source DataFrame into batches of N columns.
batch_headers – A list of lists of strings can be provided, which will control the number of batches. The number of inner lists is the number of batches, and each inner list represents the columns that belong to that batch.
config – A template training config to use; this will be used as kwargs for each Batch’s synthetic configuration. This may also be a subclass of BaseConfig. If this is used, you can set the input_data_path param to the constant PATH_HOLDER, as its value does not matter.
tokenizer_class – An optional BaseTokenizerTrainer subclass. If not provided, the default tokenizer will be used for the underlying ML engine.
Note
When providing a config, the source of training data is not necessary; only the checkpoint_dir is needed. Each batch will control its input training data path after it creates the training dataset.
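For explicit control over column grouping, here is a sketch using batch_headers; the column names in the inner lists are illustrative and must exist in my_df:

from gretel_synthetics.batch import DataFrameBatch, PATH_HOLDER
from gretel_synthetics.config import TensorFlowConfig

# Per the note above, only checkpoint_dir matters when a config is provided.
config = TensorFlowConfig(
    checkpoint_dir="/tmp/batch-models",
    input_data_path=PATH_HOLDER,
)

# Two batches: each inner list is one batch's columns.
headers = [["name", "email"], ["age", "city", "state"]]
batcher = DataFrameBatch(df=my_df, batch_headers=headers, config=config)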
batch_size: int = None¶
The max number of columns allowed for a single DF batch.
batch_to_df(batch_idx: int) → pandas.core.frame.DataFrame¶
Extract a synthetic data DataFrame from a single batch.
Parameters
batch_idx – The batch number
Returns
A DataFrame with synthetic data
batches: Dict[int, Batch] = None¶
A mapping of batch numbers (keys) to Batch objects. The batch number increments from 0..N, where N is the number of batches being used.
batches_to_df() → pandas.core.frame.DataFrame¶
Convert all batches to a single synthetic data DataFrame.
Returns
A single DataFrame that is the concatenation of all the batch DataFrames.
config: Union[dict, BaseConfig] = None¶
The template config that will be used for all batches. If a dict is provided, we default to a TensorFlowConfig.
create_training_data()¶
Split the original DataFrame into N smaller DataFrames. Each smaller DataFrame will have the same number of rows, but a subset of the columns from the original DataFrame.
This method iterates over each Batch object and assigns a smaller training DataFrame to the training_df attribute of the object.
Finally, a training CSV is written to disk in the specific batch directory.
generate_all_batch_lines(max_invalid=1000, raise_on_failed_batch: bool = False, num_lines: int = None, seed_fields: Union[dict, List[dict]] = None, parallelism: int = 0) → Dict[int, gretel_synthetics.batch.GenerationSummary]¶
Generate synthetic lines for all batches. Lines for each batch are added to the individual Batch objects. Once generation is done, you may re-assemble the dataset into a DataFrame.
Example:
my_batch.generate_all_batch_lines()
# Wait for all generation to complete
synthetic_df = my_batch.batches_to_df()
Parameters
max_invalid – The number of invalid lines allowed per batch. If this number is exceeded for any batch, generation will stop.
raise_on_failed_batch – If True, then an exception will be raised if any single batch fails to generate the requested number of lines. If False, the failed batch will be returned with is_valid=False in the result dictionary from this method.
num_lines – The number of lines to create from each batch. If None, then the value from the config template will be used.
Note: Will be overridden / ignored if seed_fields is a list; it will be set to the length of that list.
seed_fields – A dictionary that maps field/column names to initial seed values for those columns. This seed will only apply to the first batch that gets trained and generated. Additionally, the fields provided in the mapping MUST exist at the front of the first batch.
Note: This param may also be a list of dicts. If this is the case, then num_lines will automatically be set to the list length downstream, and a 1:1 ratio will be used for generating valid lines for each prefix.
parallelism – The number of concurrent workers to use. 1 (the default) disables parallelization, while a non-positive value means “number of CPUs + x” (i.e., use 0 to use as many workers as there are CPUs). A floating-point value is interpreted as a fraction of the available CPUs, rounded down.
Returns
A dictionary mapping each batch number to a GenerationSummary that reports the number of valid and invalid lines and a bool showing whether the batch was able to generate the full number of requested lines:
{
    0: GenerationSummary(valid_lines=1000, invalid_lines=10, is_valid=True),
    1: GenerationSummary(valid_lines=500, invalid_lines=5, is_valid=True)
}
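As a sketch of seeding the first batch, assuming its leading columns are named state and city (both the names and values are illustrative):

seeds = {"state": "CA", "city": "San Diego"}
results = my_batch.generate_all_batch_lines(seed_fields=seeds)

# Each value is a GenerationSummary; re-assemble only if every batch succeeded.
if all(summary.is_valid for summary in results.values()):
    synthetic_df = my_batch.batches_to_df()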
generate_batch_lines(batch_idx: int, max_invalid=1000, raise_on_exceed_invalid: bool = False, num_lines: int = None, seed_fields: Union[dict, List[dict]] = None, parallelism: int = 0) → gretel_synthetics.batch.GenerationSummary¶
Generate lines for a single batch. Lines generated are added to the underlying Batch object for each batch. The lines can be accessed after generation and re-assembled into a DataFrame.
Parameters
batch_idx – The batch number
max_invalid – The max number of invalid lines that can be generated; if this is exceeded, generation will stop.
raise_on_exceed_invalid – If True, and the number of invalid lines generated exceeds the max_invalid amount, we will re-raise the error thrown by the generation module, which will interrupt the running process. Otherwise, we will not raise the caught exception, and the returned GenerationSummary will have is_valid=False, indicating that the batch failed to generate all lines.
num_lines – The number of lines to generate; if None, then we use the number from the batch’s config.
seed_fields – A dictionary that maps field/column names to initial seed values for those columns. This seed will only apply to the first batch that gets trained and generated. Additionally, the fields provided in the mapping MUST exist at the front of the first batch.
Note: This param may also be a list of dicts. If this is the case, then num_lines will automatically be set to the list length downstream, and a 1:1 ratio will be used for generating valid lines for each prefix.
parallelism – The number of concurrent workers to use. 1 (the default) disables parallelization, while a non-positive value means “number of CPUs + x” (i.e., use 0 to use as many workers as there are CPUs). A floating-point value is interpreted as a fraction of the available CPUs, rounded down.
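A sketch of working with a single batch, continuing from the constructor example above; the batch index 0 and the max_invalid value are illustrative:

my_batch.train_batch(0)
summary = my_batch.generate_batch_lines(0, max_invalid=500)
if summary.is_valid:
    batch_df = my_batch.batch_to_df(0)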
master_header_list: List[str] = None¶
During training, this is the original column order. When reading from disk, we concatenate all headers from all batches together. This list is not guaranteed to preserve the original header order.
original_headers: List[str] = None¶
Stores the original header list / order from the original training data that was used. This is written out to the model directory during training and loaded back in when using read-only mode.
set_batch_validator(batch_idx: int, validator: Callable)¶
Set a validator for a specific batch. If a validator is configured for a batch, each generated record from that batch will be sent to the validator.
Parameters
batch_idx – The batch number
validator – A callable that should take exactly one argument, which will be the raw line generated from the generate_text function.
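A sketch of a per-batch validator, assuming a comma-delimited batch with five columns (both assumptions are illustrative):

def five_field_validator(raw_line: str):
    # Raise (or return False) to mark the generated line as invalid.
    if len(raw_line.split(",")) != 5:
        raise ValueError("wrong number of fields")

my_batch.set_batch_validator(0, five_field_validator)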
train_all_batches()¶
Train a model for each batch.
train_batch(batch_idx: int)¶
Train a model for a single batch. All model information will be written into that batch’s directory.
Parameters
batch_idx – The index of the batch, from the batches dictionary
class gretel_synthetics.batch.GenerationProgress(current_valid_count: int = 0, current_invalid_count: int = 0, new_valid_count: int = 0, new_invalid_count: int = 0, completion_percent: float = 0.0, timestamp: float = <factory>)¶
This class should not have to be used directly. It is used to communicate the current progress of record generation.
When a callback function is passed to the RecordFactory.generate_all() method, each time the callback is called, an instance of this class will be passed as the single argument:
def my_callback(data: GenerationProgress):
    ...

factory: RecordFactory
df = factory.generate_all(output="df", callback=my_callback)
completion_percent: float = 0.0¶
The percentage of valid lines/records that have been generated.
current_invalid_count: int = 0¶
The number of invalid lines/records that were generated so far.
current_valid_count: int = 0¶
The number of valid lines/records that were generated so far.
new_invalid_count: int = 0¶
The number of new invalid lines/records that were generated since the last progress callback.
new_valid_count: int = 0¶
The number of new valid lines/records that were generated since the last progress callback.
timestamp: float = None¶
The timestamp at which the information in this object was captured.
class gretel_synthetics.batch.GenerationResult(records: Union[pandas.core.frame.DataFrame, List[dict]], exception: Union[Exception, NoneType] = None)¶
class gretel_synthetics.batch.GenerationSummary(valid_lines: int = 0, invalid_lines: int = 0, is_valid: bool = False)¶
A class to capture the summary data after synthetic data is generated.
class gretel_synthetics.batch.RecordFactory(*, num_lines: int, batches: dict, header_list: list, delimiter: str, seed_fields: Union[dict, list] = None, max_invalid=1000, validator: Optional[Callable] = None, parallelism: int = 4, invalid_cache_size: int = 100)¶
A stateful factory that can be used to generate and validate entire records, regardless of the number of underlying header clusters that were used to build multiple sub-models.
Instances of this class should be created by calling the appropriate method of the DataFrameBatch instance. This class should not have to be used directly. You should be able to create an instance like so:
factory = batcher.create_record_factory(num_lines=50)
The class is init’d with default capacity and limits as specified by the num_lines and max_invalid attributes. At any time, you can inspect the state of the instance by doing:
factory.summary
The factory instance can be used one of two ways: buffered or unbuffered.
For unbuffered mode, the entire instance can be used as an iterator to create synthetic records. Each record will be a dictionary.
Note
All values in the generated dictionaries will be strings.
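A minimal sketch of unbuffered use; printing each record stands in for real processing:

factory = batcher.create_record_factory(num_lines=50)
for record in factory:        # each record is a dict of column name -> string
    print(record)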
The valid_count and invalid_count counters will update as records are generated.
When creating the record factory, you may also provide an entire record validator:
def validator(rec: dict):
    ...

factory = batcher.create_record_factory(num_lines=50, validator=validator)
Each generated record dict will be passed to the validator. This validator may either return False or raise an exception to mark a record as invalid.
At any point, you may reset the state of the factory by calling:
factory.reset()
This will reset all counters and allow you to keep generating records.
Finally, you can generate records in buffered mode, where generated records will be buffered in memory and returned as one collection. By default, a list of dicts will be returned:
factory.generate_all()
You may request the records to be returned as a DataFrame. The dtypes will be inferred as if you were reading the data from a CSV:
factory.generate_all(output="df")
Note
When using generate_all, the factory state will be reset automatically.
generate_all
(output: Optional[str] = None, callback: Optional[callable] = None, callback_interval: int = 30, callback_threading: bool = False) → gretel_synthetics.batch.GenerationResult¶ Attempt to generate the full number of records that was set when creating the
RecordFactory.
This method will create a buffer that holds all records and then returns the the buffer once generation is complete.- Parameters
output – How the records should be returned. If None, which is the default, then a list of record dicts will be returned. The other supported option is ‘df’ for a DataFrame.
callback – An optional callable that will periodically be called with a GenerationProgress instance as the single argument while records are being generated.
callback_interval – If using a callback, the minimum number of seconds that should occur between callbacks.
callback_threading – If enabled, a watchdog thread will be used to execute the callback. This ensures that the callback is called regardless of invalid or valid counts. If callback threading is disabled, the callback will only be called after valid records are generated. If the callback raises an exception, then a threading event will be set, which will trigger the stopping of generation.
Returns
Generated records in an object that is dependent on the output param. By default, this will be a list of dicts.
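A sketch tying generate_all to the GenerationProgress callback shown earlier; the 10-second interval is illustrative. Per the GenerationResult signature above, any exception caught during generation is available on the result:

def on_progress(progress: GenerationProgress):
    print(f"{progress.completion_percent:.0f}% complete")

result = factory.generate_all(output="df", callback=on_progress, callback_interval=10)
if result.exception is None:
    synthetic_df = result.records  # a DataFrame, since output="df"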
validator: Callable = None¶
An optional callable that will receive a fully constructed record for one final validation before returning or yielding a single record. Records that do not pass this validation will also increment the invalid_count.