Train confidence model
Description
QWEN, like most Large Language Models (LLMs), tends to be overconfident in its predictions. However, reliable confidence scores are crucial for assessing prediction quality on unlabelled samples.
This script trains an external confidence estimation model on features extracted from QWEN's outputs. The model is trained to estimate output quality, measured as 1 - CER (Character Error Rate).
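The target is 1 - CER. The following is a minimal sketch that assumes the standard definition of CER: the character-level Levenshtein distance divided by the length of the reference transcription. The helper names `levenshtein` and `target_score` are hypothetical and are not part of teklia-qwen-scripts.

```python
def levenshtein(reference: str, hypothesis: str) -> int:
    """Character-level edit distance (insertions, deletions, substitutions)."""
    previous = list(range(len(hypothesis) + 1))
    for i, ref_char in enumerate(reference, start=1):
        current = [i]
        for j, hyp_char in enumerate(hypothesis, start=1):
            current.append(min(
                previous[j] + 1,                           # deletion
                current[j - 1] + 1,                        # insertion
                previous[j - 1] + (ref_char != hyp_char),  # substitution
            ))
        previous = current
    return previous[-1]


def target_score(reference: str, hypothesis: str) -> float:
    """Return 1 - CER, with CER normalised by the reference length (assumption)."""
    if not reference:
        raise ValueError("reference transcription must not be empty")
    cer = levenshtein(reference, hypothesis) / len(reference)
    return 1.0 - cer
```

For example, `target_score("confidence", "confidense")` returns 0.9, because there is one substitution over ten reference characters. CER has no upper bound when the prediction contains many insertions, so 1 - CER can drop below zero.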
The confidence model is an ensemble of four models:
Model training
Use the teklia-qwen-scripts train-confidence command to train a confidence estimator based on QWEN features.
| Parameter | Description | Type | Default |
|---|---|---|---|
|  | Path to CSV datasets |  |  |
|  | Path where the model will be saved |  |  |
|  | Whether to apply data augmentation |  |  |
Requirements
Installation
To use this command, make sure to install the required dependencies by running pip install -e .[confidence].
This will install the following packages:
- pandas to load and manipulate datasets.
- sklearn-onnx to export scikit-learn models to ONNX.
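You can check that these optional dependencies are available before running the command. The sketch below uses `importlib.util.find_spec`. Note that the sklearn-onnx distribution is imported as `skl2onnx`. The function name `missing_dependencies` is hypothetical.

```python
import importlib.util
from collections.abc import Mapping

# Distribution name -> top-level import name for the "confidence" extras listed above.
CONFIDENCE_DEPENDENCIES = {"pandas": "pandas", "sklearn-onnx": "skl2onnx"}


def missing_dependencies(modules: Mapping[str, str] = CONFIDENCE_DEPENDENCIES) -> list[str]:
    """Return the distribution names whose import module cannot be found."""
    return [
        package
        for package, module in modules.items()
        if importlib.util.find_spec(module) is None
    ]


if __name__ == "__main__":
    missing = missing_dependencies()
    if missing:
        print("Missing:", ", ".join(missing), '- run: pip install -e ".[confidence]"')
    else:
        print("All confidence dependencies are installed.")
```

The extras specifier is quoted in the hint because some shells, such as zsh, treat unquoted square brackets as glob patterns.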
Data preparation
Before running the command, ensure that you have extracted features from QWEN and saved them as CSV files named train.csv and test.csv. Each CSV file should contain at least:
- A "target" column representing the target score (1 - CER);
- Additional columns representing extracted features.
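A quick pre-flight check can catch malformed files before training. This is a minimal sketch. It assumes that `train.csv` and `test.csv` sit in the same directory and must expose the same set of columns. `read_header` and `check_datasets` are hypothetical helpers and are not part of the command.

```python
import csv
from pathlib import Path


def read_header(csv_path: Path) -> list[str]:
    """Return the column names on the first line of a CSV file ([] if empty)."""
    with csv_path.open(newline="") as handle:
        return next(csv.reader(handle), [])


def check_datasets(dataset_dir: Path, target: str = "target") -> list[str]:
    """Validate train.csv and test.csv and return their feature column names.

    Assumes (hypothetically) that both files sit in the same directory and
    must expose the same set of columns, in any order.
    """
    headers = {}
    for split in ("train", "test"):
        header = read_header(dataset_dir / f"{split}.csv")
        if target not in header:
            raise ValueError(f"{split}.csv has no {target!r} column")
        if len(header) < 2:
            raise ValueError(f"{split}.csv has no feature columns")
        headers[split] = header
    if set(headers["train"]) != set(headers["test"]):
        raise ValueError("train.csv and test.csv do not have the same columns")
    # Feature columns are everything except the target, in train.csv order.
    return [name for name in headers["train"] if name != target]
```

For example, `check_datasets(Path("path/to/datasets"))` returns the feature names or raises `ValueError`. The directory path here is a placeholder.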
Example CSV format:
image_h,image_w,image_mean_pixel,image_std_pixel,output_mean_softmax,output_std_softmax,output_mean_top2,output_std_top2,output_length,target
1120.0,2000.0,208.60644330357144,16.19615444882727,0.9816937515610142,0.07268391945398575,0.9680195166488557,0.12647240237776525,278.0,0.9568345323741008
672.0,2000.0,181.49323511904763,19.164813028126623,0.9951708530124865,0.030643435174466792,0.9906568195145824,0.059549330852306216,283.0,0.9823321554770318
646.0,2000.0,183.73134210526317,17.488067033559506,0.9884472043689237,0.07493327136158669,0.9826821252662437,0.10793178706749104,315.0,0.9841269841269842
1120.0,2000.0,208.53210401785714,14.106565994368221,0.9867921951570009,0.07147337880228842,0.9787936390046726,0.12191978277770296,280.0,0.9892857142857143
2000.0,1463.0,173.8266917293233,27.000648946505837,0.9927339029836131,0.030929712734256685,0.9869337497849896,0.0543110395457957,257.0,0.8823529411764706
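To make the expected layout concrete, the sketch below splits the first two rows of this example into feature names, a feature matrix and the target vector, using only the standard library. The function name `load_features` is hypothetical. It only illustrates how the columns relate and does not reproduce how the command itself loads data.

```python
import csv
import io
from typing import TextIO

# First two rows of the example above, held in memory instead of a file.
EXAMPLE_CSV = """\
image_h,image_w,image_mean_pixel,image_std_pixel,output_mean_softmax,output_std_softmax,output_mean_top2,output_std_top2,output_length,target
1120.0,2000.0,208.60644330357144,16.19615444882727,0.9816937515610142,0.07268391945398575,0.9680195166488557,0.12647240237776525,278.0,0.9568345323741008
672.0,2000.0,181.49323511904763,19.164813028126623,0.9951708530124865,0.030643435174466792,0.9906568195145824,0.059549330852306216,283.0,0.9823321554770318
"""


def load_features(handle: TextIO, target: str = "target"):
    """Split a feature CSV into (feature_names, feature rows, target values)."""
    reader = csv.DictReader(handle)
    if reader.fieldnames is None or target not in reader.fieldnames:
        raise ValueError(f"missing {target!r} column")
    feature_names = [name for name in reader.fieldnames if name != target]
    features, targets = [], []
    for row in reader:
        features.append([float(row[name]) for name in feature_names])
        targets.append(float(row[target]))
    return feature_names, features, targets


names, X, y = load_features(io.StringIO(EXAMPLE_CSV))
print(len(names), "features,", len(X), "samples")
```

With pandas, which is installed by the extra above, an equivalent split is `df = pandas.read_csv("train.csv")` followed by `X, y = df.drop(columns="target"), df["target"]`.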