dataquality.loggers.model_logger.seq2seq package#
Submodules#
dataquality.loggers.model_logger.seq2seq.chat module#
- class Seq2SeqChatModelLogger(embs=None, probs=None, logits=None, ids=None, split='', epoch=None, inference_name=None)#
Bases: Seq2SeqModelLogger
Initialize the Seq2SeqModelLogger.
In Seq2Seq, if probs is passed in, it is actually logprobs.
- logger_config: Seq2SeqChatLoggerConfig
= Seq2SeqChatLoggerConfig(labels=None, tasks=None, observed_num_labels=None, observed_labels=None, tagging_schema=None, last_epoch=0, cur_epoch=None, cur_split=None, cur_inference_name=None, training_logged=False, validation_logged=False, test_logged=False, inference_logged=False, exception='', helper_data={}, input_data_logged=defaultdict(<class 'int'>, {}), logged_input_ids=defaultdict(<class 'set'>, {}), idx_to_id_map=defaultdict(<class 'list'>, {}), conditions=[], report_emails=[], ner_labels=[], int_labels=False, feature_names=[], metadata_documents=set(), finish=<function BaseLoggerConfig.<lambda>>, existing_run=False, dataloader_random_sampling=False, remove_embs=False, sample_length={}, tokenizer=None, max_input_tokens=None, max_target_tokens=None, id_to_tokens=defaultdict(<class 'dict'>, {}), model=None, generation_config=None, generation_splits=set(), model_type=None, id_to_formatted_prompt_length=defaultdict(<class 'dict'>, {}), response_template=None)#
dataquality.loggers.model_logger.seq2seq.completion module#
- class Seq2SeqCompletionModelLogger(embs=None, probs=None, logits=None, ids=None, split='', epoch=None, inference_name=None)#
Bases: Seq2SeqModelLogger
Initialize the Seq2SeqModelLogger.
In Seq2Seq, if probs is passed in, it is actually logprobs.
- logger_config: Seq2SeqCompletionLoggerConfig
= Seq2SeqCompletionLoggerConfig(labels=None, tasks=None, observed_num_labels=None, observed_labels=None, tagging_schema=None, last_epoch=0, cur_epoch=None, cur_split=None, cur_inference_name=None, training_logged=False, validation_logged=False, test_logged=False, inference_logged=False, exception='', helper_data={}, input_data_logged=defaultdict(<class 'int'>, {}), logged_input_ids=defaultdict(<class 'set'>, {}), idx_to_id_map=defaultdict(<class 'list'>, {}), conditions=[], report_emails=[], ner_labels=[], int_labels=False, feature_names=[], metadata_documents=set(), finish=<function BaseLoggerConfig.<lambda>>, existing_run=False, dataloader_random_sampling=False, remove_embs=False, sample_length={}, tokenizer=None, max_input_tokens=None, max_target_tokens=None, id_to_tokens=defaultdict(<class 'dict'>, {}), model=None, generation_config=None, generation_splits=set(), model_type=None, id_to_formatted_prompt_length=defaultdict(<class 'dict'>, {}), response_template=None)#
dataquality.loggers.model_logger.seq2seq.formatters module#
- class BaseSeq2SeqModelFormatter(logger_config)#
Bases: ABC
- abstract format_sample(sample_id, sample_output_tokens, split_key, shift_labels=True)#
Formats sample_output_tokens before extracting token information
- Depending on the model architecture, this function:
    - Removes padding tokens from model outputs
    - Restricts to just the response / target tokens
Note: shift_labels is only used for DecoderOnly models. See further details in the DecoderOnly definition.
- Returns:
    - formatted_labels: np.ndarray
      Used for extracting token logprob data
    - formatted_sample_output_tokens: np.ndarray
- Return type:
    Tuple[ndarray, ndarray]
- retrieve_sample_labels(sample_id, max_tokens, split_key)#
Retrieve the labels array based on the sample id and truncate it at max_tokens.
Labels gives the ground truth / target token ids for each token in the sequence:
e.g. for sample_id = 8 -> labels = [0, 10, 16, …]
- Return type:
ndarray
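For orientation, here is a minimal sketch of what this lookup could look like, assuming (hypothetically) that per-sample target token ids are stored in logger_config.id_to_tokens keyed by split and sample id; the actual implementation may differ:

    import numpy as np

    def retrieve_sample_labels_sketch(logger_config, sample_id, max_tokens, split_key):
        # Hypothetical: target token ids stored per split, then per sample id
        labels = np.asarray(logger_config.id_to_tokens[split_key][sample_id])
        # Truncate to the first max_tokens ids when a limit is provided
        if max_tokens is not None:
            labels = labels[:max_tokens]
        return labels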
- class EncoderDecoderModelFormatter(logger_config)#
Bases: BaseSeq2SeqModelFormatter
Seq2Seq model formatter for EncoderDecoder models.
Since Encoder-Decoder models output logits just over the target tokens, there is very little additional processing - i.e. we primarily leverage functionality from Seq2SeqModelLogger.
- format_sample(sample_id, sample_output_tokens, split_key, shift_labels=True)#
Formats sample_output_tokens by removing padding tokens
- Return type:
    Tuple[ndarray, ndarray]
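As a rough illustration of the padding removal only (a sketch, not the library's code; it assumes padded target positions are marked with -100, the Hugging Face labels convention, which may not match the internal representation):

    import numpy as np

    PAD_LABEL = -100  # assumed padding marker for target labels (HF convention)

    def remove_label_padding_sketch(labels):
        # Keep only the real (non-padding) ground-truth token ids
        labels = np.asarray(labels)
        return labels[labels != PAD_LABEL]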
- class DecoderOnlyModelFormatter(logger_config)#
Bases: BaseSeq2SeqModelFormatter
Seq2Seq model formatter for DecoderOnly models.
Since DecoderOnly models output logits over the full formatted prompt (input + response), additional processing is needed: format_sample removes padding based on the tokenized formatted prompt length and restricts to just the response / target tokens.
- format_sample(sample_id, sample_output_tokens, split_key, shift_labels=True)#
Formats sample_output_tokens
- Return type:
    Tuple[ndarray, ndarray]
- Actions taken:
    - Removes padding tokens based on the length of the tokenized formatted prompt
    - Restricts to just response tokens using the saved response_labels
The shift_labels flag is used to align the 'logits' / 'logprobs' with the Response Token Labels. As a general rule:
    - When logging directly from non-API models (e.g. HF), the response_labels are "shifted" right by one relative to the logits. Thus, to align them - i.e. get the correct logits for each token label - we need to account for this shift (a runnable sketch follows this list).
      e.g. formatted_sample_ids = [1, 2, 3, 4, 5, 6, 7, 8]
           response_tokens_ids = [6, 7, 8]
           logits.shape = [8, vocab]
           # Output corresponding to model input tokens [5, 6, 7]
           response_logits = logits[-4:-1]  # NOT response_logits = logits[-3:]
    - When logging from an API, the logits or logprobs are generally already aligned for us, so we don't need to account for this right shift.
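The slicing above can be checked with a small numpy sketch (shapes and values are illustrative only):

    import numpy as np

    formatted_sample_ids = np.array([1, 2, 3, 4, 5, 6, 7, 8])
    response_token_ids = np.array([6, 7, 8])
    vocab_size = 50
    logits = np.random.rand(len(formatted_sample_ids), vocab_size)

    # shift_labels=True (e.g. logging directly from an HF model): the logits at
    # position t predict the token at position t + 1, so the logits for response
    # tokens [6, 7, 8] come from input positions [5, 6, 7], i.e. rows -4:-1.
    response_logits_shifted = logits[-4:-1]

    # shift_labels=False (e.g. an API that already aligns logprobs with tokens)
    response_logits_aligned = logits[-3:]

    assert response_logits_shifted.shape == (3, vocab_size)
    assert response_logits_aligned.shape == (3, vocab_size)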
- get_model_formatter(model_type, logger_config)#
Returns the model formatter for the given model_type
- Return type:
    BaseSeq2SeqModelFormatter
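A minimal usage sketch based on the signature above (variable names are illustrative):

    # Pick the formatter that matches the architecture being logged, then use it
    # to format a sample before token-level logprob extraction.
    formatter = get_model_formatter(model_type, logger_config)
    labels, output_tokens = formatter.format_sample(sample_id, sample_output_tokens, split_key)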
dataquality.loggers.model_logger.seq2seq.seq2seq_base module#
- class Seq2SeqModelLogger(embs=None, probs=None, logits=None, ids=None, split='', epoch=None, inference_name=None)#
Bases: BaseGalileoModelLogger
Seq2Seq base model logger
This class defines the base functionality for logging model outputs in Seq2Seq tasks - shared between EncoderDecoder and DecoderOnly architectures.
After architecture specific processing of raw model logits, we leverage a shared function for processing and extracting the logprob token data just over the Target data.
- During processing, the following key information is extracted:
    - token_logprobs: log-probs for GT tokens in each sample
    - top_logprobs: top-K (token_str, log-prob) pairs for each token
Initialize the Seq2SeqModelLogger.
In Seq2Seq, if probs is passed in, it is actually logprobs.
- logger_config: Seq2SeqLoggerConfig
= Seq2SeqLoggerConfig(labels=None, tasks=None, observed_num_labels=None, observed_labels=None, tagging_schema=None, last_epoch=0, cur_epoch=None, cur_split=None, cur_inference_name=None, training_logged=False, validation_logged=False, test_logged=False, inference_logged=False, exception='', helper_data={}, input_data_logged=defaultdict(<class 'int'>, {}), logged_input_ids=defaultdict(<class 'set'>, {}), idx_to_id_map=defaultdict(<class 'list'>, {}), conditions=[], report_emails=[], ner_labels=[], int_labels=False, feature_names=[], metadata_documents=set(), finish=<function BaseLoggerConfig.<lambda>>, existing_run=False, dataloader_random_sampling=False, remove_embs=False, sample_length={}, tokenizer=None, max_input_tokens=None, max_target_tokens=None, id_to_tokens=defaultdict(<class 'dict'>, {}), model=None, generation_config=None, generation_splits=set(), model_type=None, id_to_formatted_prompt_length=defaultdict(<class 'dict'>, {}), response_template=None)#
- log_file_ext = 'arrow'#
- property split_key: str#
- validate_and_format()#
Validate the lengths, calculate token level dep, extract GT probs
- Return type:
None
- process_logits(batch_ids, batch_logits)#
Process a batch of sample logit data
- For each sample in the batch, extract / compute the following values:
    - Token level logprobs for the GT label
    - Token level top-k model logprobs: represented as a dictionary mapping {predicted_token: logprob}
batch_logits has shape - [batch_size, max_token_length, vocab_size], where max_token_length is determined by the longest sample in the batch. Because other samples in the batch are padded to this max_length, we have to process each sample individually to ignore pad token indices.
- Special points of consideration:
    - For each sample, top-k logprobs is a list of dictionaries with length equal to the number of tokens in that sample. Each dictionary maps the model's top-k predicted tokens to their corresponding logprobs.
    - We return a pyarrow array because each sample may have a different number of tokens, which can't be represented in numpy.
- Returns:
    - batch_token_logprobs: GT Logprob per token
      len(batch_token_logprobs) == batch_size
      batch_token_logprobs[i].shape is [num_tokens_in_label[i]]
    - batch_top_logprobs: Top-k logprob dictionary per token
      type(batch_top_logprobs[i]) = List[Dict[str, float]]
      len(batch_top_logprobs) == batch_size
      len(batch_top_logprobs[i]) = num_tokens_in_label[i]
- Return type:
    Tuple of (batch_token_logprobs, batch_top_logprobs), returned as pyarrow arrays
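As a rough, self-contained illustration of the per-sample extraction described above (not the library's exact implementation; it assumes padding has already been removed and a Hugging Face-style tokenizer with a decode method):

    import numpy as np

    def extract_token_data_sketch(sample_logprobs, label_ids, tokenizer, k=5):
        # sample_logprobs: [num_tokens, vocab_size] logprobs for one sample
        # label_ids: [num_tokens] ground-truth token ids aligned with sample_logprobs

        # Logprob the model assigned to the ground-truth token at each position
        token_logprobs = sample_logprobs[np.arange(len(label_ids)), label_ids]

        # Top-k predicted tokens per position, as {token_str: logprob} dictionaries
        top_k_ids = np.argsort(sample_logprobs, axis=-1)[:, -k:][:, ::-1]
        top_logprobs = [
            {tokenizer.decode([int(tok)]): float(sample_logprobs[pos, tok]) for tok in row}
            for pos, row in enumerate(top_k_ids)
        ]
        return token_logprobs, top_logprobs

In the real logger these per-sample results are collected across the batch and returned as pyarrow arrays, since samples have varying numbers of tokens.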
- process_logprobs(batch_ids, batch_logprobs)#
Process a batch of sample logprob data
This is a special case where the user only logs a single logprob for each token - i.e. the label token's logprob.
batch_logprobs.shape = [bs, max_token_length]
In this case, we do not have any top_k logprob data; therefore, we fill the top_logprob data with “filler” data. Each token’s top 5 logprob data is:
[("—", -20)] * TOP_K
Similar to process_logits, we process the logprob data to 1) remove padding and 2) apply any other formatting needed to restrict to token-level information for just the "Target" tokens.
- Special points of consideration:
    - We return a pyarrow array because each sample may have a different number of tokens, which can't be represented in numpy.
- Returns:
    - batch_token_logprobs: GT Logprob per token
      len(batch_token_logprobs) == batch_size
      batch_token_logprobs[i].shape is [num_tokens_in_label[i]]
    - batch_top_logprobs: Top-k logprob dictionary per token
      type(batch_top_logprobs[i]) = List[Dict[str, float]]
      len(batch_top_logprobs) == batch_size
      len(batch_top_logprobs[i]) = num_tokens_in_label[i]
      batch_top_logprobs[i][0] = ("—", -20)
- Return type:
    Tuple of (batch_token_logprobs, batch_top_logprobs), returned as pyarrow arrays
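A small sketch of the truncation and filler behavior described above (shapes, values, and variable names are illustrative only):

    import numpy as np

    TOP_K = 5
    # Two samples padded to max_token_length = 4; sample 0 has 3 real tokens.
    batch_logprobs = np.array([[-0.1, -0.5, -2.0, 0.0],
                               [-0.3, -0.2, -0.7, -1.1]])
    num_tokens_in_label = [3, 4]

    # Keep only the real (non-padding) token logprobs for each sample
    batch_token_logprobs = [batch_logprobs[i, :n] for i, n in enumerate(num_tokens_in_label)]

    # No top-k data was logged, so every token position gets the same filler entry
    batch_top_logprobs = [[[("—", -20)] * TOP_K for _ in range(n)] for n in num_tokens_in_label]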
- convert_logits_to_logprobs(sample_logits)#
Converts logits (unnormalized log probabilities) to logprobs via log_softmax
This is a special use case for Seq2Seq: people generally work with logprobs. One reason for this is that log_softmax takes advantage of the logsumexp "trick" to compute a numerically stable version of log(softmax(x)).
- Return type:
ndarray
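For reference, a numerically stable log_softmax can be written with the logsumexp trick as follows (a generic numpy sketch, not the library's exact code):

    import numpy as np

    def log_softmax(logits, axis=-1):
        # Subtracting the row max first keeps exp() from overflowing; the result
        # is log(softmax(x)) computed as x - logsumexp(x).
        shifted = logits - logits.max(axis=axis, keepdims=True)
        return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

    sample_logits = np.array([[2.0, 1.0, 0.1]])
    sample_logprobs = log_softmax(sample_logits)
    assert np.allclose(np.exp(sample_logprobs).sum(axis=-1), 1.0)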