1212logger = logging .getLogger (__name__ )
1313
1414DEFAULT_DATASET_ARTIFACT = "synthetic_intent_dataset"
15+ DEFAULT_DATASET_FILENAME = "synthetic_intents.csv"
1516DEFAULT_EVALUATION_MODEL = "dual-encoder-text-embedding-1024"
1617
1718
@@ -60,21 +61,26 @@ def __init__(
6061 api_key = search_api_key ,
6162 )
6263
63- def _load_dataset_from_wandb (self , artifact_name : str ) -> pd .DataFrame :
64+ def _load_dataset_from_wandb (self , artifact_name : str , dataset_filename : str ) -> pd .DataFrame :
6465 """
6566 Load a dataset from a W&B artifact.
6667
6768 Args:
6869 artifact_name: Name of the W&B artifact
69-
70+ dataset_filename: Name of the CSV file inside the downloaded artifact directory to read
7071 Returns:
7172 DataFrame containing the dataset
7273 """
7374 artifact = wandb .use_artifact (f"{ artifact_name } :latest" )
7475 artifact_dir = artifact .download ()
75- return pd .read_csv (os .path .join (artifact_dir , "temp_dataset.csv" ))
76+ return pd .read_csv (os .path .join (artifact_dir , dataset_filename ))
7677
77- def _generate (self , dataset_artifact : str , generation_limit : int | None = None ) -> pd .DataFrame :
78+ def _generate (
79+ self ,
80+ dataset_artifact : str ,
81+ dataset_filename : str ,
82+ generation_limit : int | None = None ,
83+ ) -> pd .DataFrame :
7884 """
7985 Generate synthetic intents.
8086
@@ -88,6 +94,7 @@ def _generate(self, dataset_artifact: str, generation_limit: int | None = None)
8894 logger .info ("Generating synthetic intents..." )
8995 df = self .generator .generate (
9096 dataset_artifact = dataset_artifact ,
97+ dataset_filename = dataset_filename ,
9198 limit = generation_limit ,
9299 )
93100
@@ -97,6 +104,7 @@ def _generate(self, dataset_artifact: str, generation_limit: int | None = None)
97104 def _evaluate (
98105 self ,
99106 dataset_artifact : str ,
107+ dataset_filename : str ,
100108 evaluation_samples : int | None = None ,
101109 df : pd .DataFrame | None = None ,
102110 ) -> dict :
@@ -113,7 +121,7 @@ def _evaluate(
113121 """
114122 if df is None :
115123 logger .info (f"Loading dataset from artifact: { dataset_artifact } " )
116- df = self ._load_dataset_from_wandb (dataset_artifact )
124+ df = self ._load_dataset_from_wandb (dataset_artifact , dataset_filename )
117125
118126 # Evaluate search performance
119127 logger .info ("Evaluating search performance..." )
@@ -138,9 +146,10 @@ def _evaluate(
138146
139147 def run (
140148 self ,
149+ dataset_artifact : str ,
150+ dataset_filename : str ,
141151 generate_data : bool = False ,
142152 evaluate_data : bool = True ,
143- dataset_artifact : str = DEFAULT_DATASET_ARTIFACT ,
144153 generation_limit : int | None = None ,
145154 evaluation_samples : int | None = None ,
146155 ) -> None :
@@ -169,6 +178,7 @@ def run(
169178 "evaluation_model" : DEFAULT_EVALUATION_MODEL ,
170179 "evaluation_samples" : evaluation_samples ,
171180 "dataset_artifact" : dataset_artifact ,
181+ "dataset_filename" : dataset_filename ,
172182 },
173183 )
174184
@@ -177,12 +187,14 @@ def run(
177187 if generate_data :
178188 df = self ._generate (
179189 dataset_artifact = dataset_artifact ,
190+ dataset_filename = dataset_filename ,
180191 generation_limit = generation_limit ,
181192 )
182193
183194 if evaluate_data :
184195 self ._evaluate (
185196 dataset_artifact = dataset_artifact ,
197+ dataset_filename = dataset_filename ,
186198 evaluation_samples = evaluation_samples ,
187199 df = df ,
188200 )
@@ -199,15 +211,26 @@ def run(
199211 required = True ,
200212)
201213@click .option (
202- "--dataset" ,
214+ "--dataset-artifact " ,
203215 default = DEFAULT_DATASET_ARTIFACT ,
204216 help = "Name of the W&B dataset artifact to use" ,
205217 show_default = True ,
206218)
219+ @click .option (
220+ "--dataset-filename" ,
221+ default = DEFAULT_DATASET_FILENAME ,
222+ type = str ,
223+ help = "Filename to save the generated dataset to" ,
224+ show_default = True ,
225+ )
207226@click .option ("--generation-limit" , type = int , help = "Limit number of samples to generate" )
208227@click .option ("--evaluation-samples" , type = int , help = "Limit number of samples to evaluate" )
209228def main (
210- mode : str , dataset : str , generation_limit : int | None , evaluation_samples : int | None
229+ mode : str ,
230+ dataset_artifact : str ,
231+ generation_limit : int | None ,
232+ evaluation_samples : int | None ,
233+ dataset_filename : str ,
211234) -> None :
212235 """Main entry point for the evaluation pipeline."""
213236 # Get API keys from environment
@@ -235,9 +258,10 @@ def main(
235258
236259 # Run pipeline
237260 pipeline .run (
261+ dataset_artifact = dataset_artifact ,
262+ dataset_filename = dataset_filename ,
238263 generate_data = generate_data ,
239264 evaluate_data = evaluate_data ,
240- dataset_artifact = dataset ,
241265 generation_limit = generation_limit ,
242266 evaluation_samples = evaluation_samples ,
243267 )
0 commit comments