Supporting New Inferencers
The inferencer is the core component in AISBench that executes model inference. It uses a different inference method depending on the model type (API model or local model). Before adapting a new inferencer, we recommend first reading how prompt_template and meta_template are defined, to understand how AISBench constructs prompts.
Currently, AISBench supports the following inferencer types:
GenInferencer: Inferencer for generative tasks, supporting API models and local models
MultiTurnGenInferencer: Inferencer for multi-turn dialogue tasks, supporting API models and local models
PPLInferencer: Inferencer for perplexity (PPL) evaluation
For special inference scenarios or custom requirements, you will usually need to implement a custom inferencer. Which interface it must implement depends on the type of model being called:
API models: implement the async do_request method, which calls the served model through HTTP requests.
Local models: implement the synchronous batch_inference method, which calls the local model directly for batch inference.
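To make the async/sync split concrete, here is a minimal, self-contained sketch. ToyApiInferencer and ToyLocalInferencer are hypothetical stand-ins, not the AISBench base classes, and uppercasing the prompt stands in for a real model call. The sketch uses asyncio.Semaphore for rate limiting, whereas the AISBench example below types token_bucket as a multiprocessing BoundedSemaphore.

```python
import asyncio
from typing import Dict, List


class ToyApiInferencer:
    """Hypothetical API-style inferencer: one async request per sample."""

    def __init__(self, max_concurrency: int = 2) -> None:
        self.max_concurrency = max_concurrency
        self.results: Dict[int, str] = {}

    async def do_request(self, data: dict, token_bucket: asyncio.Semaphore) -> None:
        async with token_bucket:  # cap the number of in-flight requests
            await asyncio.sleep(0)  # placeholder for the HTTP call to a served model
            self.results[data["index"]] = data["prompt"].upper()

    async def run(self, data_list: List[dict]) -> Dict[int, str]:
        token_bucket = asyncio.Semaphore(self.max_concurrency)
        await asyncio.gather(*(self.do_request(d, token_bucket) for d in data_list))
        return self.results


class ToyLocalInferencer:
    """Hypothetical local-style inferencer: one blocking call per batch."""

    def __init__(self, batch_size: int = 2) -> None:
        self.batch_size = batch_size
        self.results: Dict[int, str] = {}

    def batch_inference(self, datum: dict) -> None:
        outputs = [p.upper() for p in datum["prompt"]]  # placeholder for model.generate
        for index, output in zip(datum["index"], outputs):
            self.results[index] = output

    def run(self, data_list: List[dict]) -> Dict[int, str]:
        for start in range(0, len(data_list), self.batch_size):
            batch = data_list[start:start + self.batch_size]
            # Group per-sample dicts into one dict of per-field lists
            datum = {
                "index": [d["index"] for d in batch],
                "prompt": [d["prompt"] for d in batch],
            }
            self.batch_inference(datum)
        return self.results


data_list = [{"index": i, "prompt": p} for i, p in enumerate(["a", "b", "c"])]
api_results = asyncio.run(ToyApiInferencer().run(data_list))
local_results = ToyLocalInferencer().run(data_list)
```

The API path schedules one coroutine per sample and lets the semaphore bound concurrency; the local path makes one blocking call per batch, using the dict-of-lists layout that the batch_inference docstring later in this page describes.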
Adding API Model Inferencers
To add an inferencer for API models, create a new file my_custom_api_inferencer.py in ais_bench/benchmark/openicl/icl_inferencer, inherit from BaseApiInferencer, and implement the interfaces your scenario requires. The currently supported extensible interfaces are:
do_request (required): executes a single inference request; used for asynchronous inference with API models.
get_data_list (required): gets data from the retriever and constructs the inference data list.
```python
from multiprocessing import BoundedSemaphore
from typing import List, Optional
import uuid
import copy

import aiohttp

from ais_bench.benchmark.models.output import RequestOutput
from ais_bench.benchmark.registry import ICL_INFERENCERS
from ais_bench.benchmark.openicl.icl_retriever import BaseRetriever
from ais_bench.benchmark.openicl.icl_inferencer.icl_base_api_inferencer import BaseApiInferencer
from ais_bench.benchmark.openicl.icl_inferencer.output_handler.gen_inferencer_output_handler import GenInferencerOutputHandler


@ICL_INFERENCERS.register_module()
class MyCustomApiInferencer(BaseApiInferencer):
    """Custom API model inferencer class.

    Attributes:
        model_cfg: Model configuration
        batch_size (:obj:`int`, optional): Batch size
        output_json_filepath (:obj:`str`, optional): Output JSON file path
        save_every (:obj:`int`, optional): Save intermediate results every N samples
    """

    def __init__(
        self,
        model_cfg,
        batch_size: Optional[int] = 1,
        mode: Optional[str] = "infer",
        output_json_filepath: Optional[str] = "./icl_inference_output",
        save_every: Optional[int] = 1,
        **kwargs,
    ) -> None:
        super().__init__(
            model_cfg=model_cfg,
            batch_size=batch_size,
            mode=mode,
            output_json_filepath=output_json_filepath,
            save_every=save_every,
            **kwargs,
        )
        # Initialize output handler
        self.output_handler = GenInferencerOutputHandler(
            perf_mode=self.perf_mode,
            save_every=self.save_every
        )

    async def do_request(
        self,
        data: dict,
        token_bucket: BoundedSemaphore,
        session: aiohttp.ClientSession
    ):
        """Execute a single inference request.

        Args:
            data: Dictionary containing request data, usually includes the following fields:
                - prompt: Input prompt
                - index: Data index
                - data_abbr: Dataset identifier
                - max_out_len: Maximum output length
                - gold: Ground truth (optional)
            token_bucket: Semaphore for rate limiting
            session: HTTP session object
        """
        data = copy.deepcopy(data)
        index = data.pop("index")
        input = data.pop("prompt")
        data_abbr = data.pop("data_abbr")
        max_out_len = data.pop("max_out_len")
        gold = data.pop("gold", None)

        # Generate unique identifier
        uid = str(uuid.uuid4()).replace("-", "")
        output = RequestOutput(self.perf_mode)
        output.uuid = uid

        # Update status counter
        await self.status_counter.post()

        # Call model for inference; any remaining fields in data are forwarded as kwargs
        await self.model.generate(input, max_out_len, output, session=session, **data)

        # Update status
        if output.success:
            await self.status_counter.rev()
        else:
            await self.status_counter.failed()
        await self.status_counter.finish()
        await self.status_counter.case_finish()

        # Report results to output handler
        await self.output_handler.report_cache_info(index, input, output, data_abbr, gold)

    def get_data_list(
        self,
        retriever: BaseRetriever,
    ) -> List:
        """Get data list from retriever.

        Args:
            retriever: Retriever instance, used to get data and generate prompts

        Returns:
            Data list, each element is a dictionary containing information needed for inference
        """
        data_abbr = retriever.dataset.abbr
        ice_idx_list = retriever.retrieve()
        prompt_list = []

        # Generate prompt for each sample
        for idx, ice_idx in enumerate(ice_idx_list):
            ice = retriever.generate_ice(ice_idx)
            prompt = retriever.generate_prompt_for_generate_task(
                idx,
                ice,
                gen_field_replace_token=getattr(self, "gen_field_replace_token", ""),
            )
            # Parse template
            parsed_prompt = self.model.parse_template(prompt, mode="gen")
            prompt_list.append(parsed_prompt)
        self.logger.info("Apply ice template finished")

        # Get ground truth
        gold_ans = retriever.get_gold_ans()

        # Build data list
        data_list = []
        for index, prompt in enumerate(prompt_list):
            data_list.append(
                {
                    "prompt": prompt,
                    "data_abbr": data_abbr,
                    "index": index,
                    "max_out_len": self.model.max_out_len,
                }
            )

        # Add ground truth
        if gold_ans is not None:
            for index, gold in enumerate(gold_ans):
                data_list[index]["gold"] = gold

        # Dataset-specified max_out_len has highest priority
        max_out_lens = retriever.dataset_reader.get_max_out_len()
        if max_out_lens is not None:
            self.logger.warning("Dataset-specified max_out_len has highest priority, use dataset-specified max_out_len")
            for index, max_out_len in enumerate(max_out_lens):
                data_list[index]["max_out_len"] = max_out_len if max_out_len else self.model.max_out_len
        return data_list
```
We recommend exporting the new inferencer class from the __init__.py of ais_bench/benchmark/openicl/icl_inferencer so that it can be imported directly later.
For a complete implementation, see GenInferencer.
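The bookkeeping at the end of get_data_list (attach gold answers, then let dataset-specified max_out_len values override the model default) is easy to get wrong, so it can help to test it in isolation. The sketch below extracts that logic into a plain function with no retriever or model; build_data_list is a hypothetical name, not an AISBench API.

```python
from typing import Any, List, Optional


def build_data_list(
    prompts: List[Any],
    data_abbr: str,
    model_max_out_len: int,
    gold_ans: Optional[List[Any]] = None,
    dataset_max_out_lens: Optional[List[Optional[int]]] = None,
) -> List[dict]:
    """Standalone mirror of the data-list construction in get_data_list."""
    data_list = [
        {"prompt": p, "data_abbr": data_abbr, "index": i, "max_out_len": model_max_out_len}
        for i, p in enumerate(prompts)
    ]
    if gold_ans is not None:
        for i, gold in enumerate(gold_ans):
            data_list[i]["gold"] = gold
    if dataset_max_out_lens is not None:
        # Dataset values win; falsy entries (None or 0) fall back to the model default
        for i, max_out_len in enumerate(dataset_max_out_lens):
            data_list[i]["max_out_len"] = max_out_len if max_out_len else model_max_out_len
    return data_list


rows = build_data_list(
    ["q0", "q1"], "demo", 512, gold_ans=["a0", "a1"], dataset_max_out_lens=[128, None]
)
# rows[0]["max_out_len"] is 128 (dataset value); rows[1]["max_out_len"] is 512 (fallback)
```

Note that a dataset value of 0 is treated the same as a missing value, because the original code uses a truthiness check rather than `is None`.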
Adding Local Model Inferencers
To add an inferencer for local models, create a new file my_custom_local_inferencer.py in ais_bench/benchmark/openicl/icl_inferencer, inherit from BaseLocalInferencer, and implement the interfaces your scenario requires. The currently supported extensible interfaces are:
batch_inference (required): executes batch inference; used for synchronous inference with local models.
get_data_list (required): gets data from the retriever and constructs the inference data list.
```python
from typing import List, Optional

from torch.utils.data import DataLoader

from ais_bench.benchmark.registry import ICL_INFERENCERS
from ais_bench.benchmark.openicl.icl_retriever import BaseRetriever
from ais_bench.benchmark.openicl.icl_inferencer.icl_base_local_inferencer import BaseLocalInferencer
from ais_bench.benchmark.openicl.icl_inferencer.output_handler.gen_inferencer_output_handler import GenInferencerOutputHandler


@ICL_INFERENCERS.register_module()
class MyCustomLocalInferencer(BaseLocalInferencer):
    """Custom local model inferencer class.

    Attributes:
        model_cfg: Model configuration
        batch_size (:obj:`int`, optional): Batch size
        output_json_filepath (:obj:`str`, optional): Output JSON file path
        save_every (:obj:`int`, optional): Save intermediate results every N samples
    """

    def __init__(
        self,
        model_cfg,
        batch_size: Optional[int] = 1,
        output_json_filepath: Optional[str] = "./icl_inference_output",
        save_every: Optional[int] = 1,
        **kwargs,
    ) -> None:
        super().__init__(
            model_cfg=model_cfg,
            batch_size=batch_size,
            output_json_filepath=output_json_filepath,
        )
        self.save_every = save_every
        # Initialize output handler
        self.output_handler = GenInferencerOutputHandler(
            perf_mode=False,  # Local inferencers usually don't support performance mode
            save_every=self.save_every
        )

    def batch_inference(
        self,
        datum: dict,
    ) -> None:
        """Execute batch inference.

        Args:
            datum: Dictionary containing batch data, usually includes the following fields:
                - prompt: Input prompt list
                - index: Data index list
                - data_abbr: Dataset identifier list
                - max_out_len: Maximum output length list
                - gold: Ground truth list (optional)
        """
        indexs = datum.pop("index")
        inputs = datum.pop("prompt")
        data_abbrs = datum.pop("data_abbr")
        max_out_lens = datum.pop("max_out_len")
        golds = datum.pop("gold", [None] * len(inputs))

        # Call local model for batch inference
        # Local models use unified max_out_len from model configuration
        outputs = self.model.generate(inputs, self.model.max_out_len, **datum)

        # Process each output result
        for index, input, output, data_abbr, gold in zip(
            indexs, inputs, outputs, data_abbrs, golds
        ):
            self.output_handler.report_cache_info_sync(
                index, input, output, data_abbr, gold
            )

    def get_data_list(
        self,
        retriever: BaseRetriever,
    ) -> List:
        """Get data list from retriever.

        Args:
            retriever: Retriever instance, used to get data and generate prompts

        Returns:
            Data list, each element is a dictionary containing information needed for inference
        """
        data_abbr = retriever.dataset.abbr
        ice_idx_list = retriever.retrieve()
        prompt_list = []

        # Generate prompt for each sample
        for idx, ice_idx in enumerate(ice_idx_list):
            ice = retriever.generate_ice(ice_idx)
            prompt = retriever.generate_prompt_for_generate_task(
                idx,
                ice,
                gen_field_replace_token=getattr(self, "gen_field_replace_token", ""),
            )
            # Parse template
            parsed_prompt = self.model.parse_template(prompt, mode="gen")
            prompt_list.append(parsed_prompt)
        self.logger.info("Apply ice template finished")

        # Get ground truth
        gold_ans = retriever.get_gold_ans()

        # Build data list
        data_list = []
        for index, prompt in enumerate(prompt_list):
            data_list.append(
                {
                    "prompt": prompt,
                    "data_abbr": data_abbr,
                    "index": index,
                    "max_out_len": self.model.max_out_len,
                }
            )

        # Add ground truth
        if gold_ans is not None:
            for index, gold in enumerate(gold_ans):
                data_list[index]["gold"] = gold

        # Dataset-specified max_out_len has highest priority
        max_out_lens = retriever.dataset_reader.get_max_out_len()
        if max_out_lens is not None:
            self.logger.warning("Dataset-specified max_out_len has highest priority, use dataset-specified max_out_len")
            for index, max_out_len in enumerate(max_out_lens):
                data_list[index]["max_out_len"] = max_out_len if max_out_len else self.model.max_out_len
        return data_list
```
We recommend exporting the new inferencer class from the __init__.py of ais_bench/benchmark/openicl/icl_inferencer so that it can be imported directly later.
For a complete implementation, see GenInferencer.
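get_data_list returns one dict per sample, but batch_inference receives a single dict whose fields are lists. The sketch below assumes that the base class groups batch_size samples into that dict-of-lists shape (as the batch_inference docstring implies) and reproduces the unpacking logic with a fake model. collate, run_batch and fake_generate are hypothetical helpers, not AISBench APIs.

```python
from typing import Callable, Dict, List


def collate(samples: List[dict]) -> Dict[str, list]:
    """Turn a list of per-sample dicts into one dict of per-field lists."""
    keys = samples[0].keys()
    return {key: [s[key] for s in samples] for key in keys}


def run_batch(
    datum: Dict[str, list],
    generate: Callable[..., List[str]],
    model_max_out_len: int,
) -> List[tuple]:
    """Mirror of batch_inference, returning records instead of reporting them."""
    datum = dict(datum)  # shallow copy so pop() does not mutate the caller's dict
    indexs = datum.pop("index")
    inputs = datum.pop("prompt")
    data_abbrs = datum.pop("data_abbr")
    datum.pop("max_out_len")  # per-sample lengths are dropped; the model-level limit is used
    golds = datum.pop("gold", [None] * len(inputs))  # gold is optional
    outputs = generate(inputs, model_max_out_len, **datum)  # leftover fields become kwargs
    return list(zip(indexs, inputs, outputs, data_abbrs, golds))


def fake_generate(prompts: List[str], max_out_len: int, **kwargs) -> List[str]:
    # Stand-in model: reverse each prompt and truncate to max_out_len characters
    return [p[::-1][:max_out_len] for p in prompts]


samples = [
    {"prompt": "abc", "data_abbr": "demo", "index": 0, "max_out_len": 8},
    {"prompt": "xyz", "data_abbr": "demo", "index": 1, "max_out_len": 8},
]
batch = collate(samples)
records = run_batch(batch, fake_generate, 2)
```

Because the samples carry no gold field, run_batch falls back to a list of None values, matching the default in batch_inference above.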
Inferencers Supporting Both API Models and Local Models
If an inferencer needs to support both API models and local models, it can inherit from both BaseApiInferencer and BaseLocalInferencer and implement the required methods of each base class. The same inferencer class can then be used with either type of model.
```python
from multiprocessing import BoundedSemaphore
from typing import List, Optional

import aiohttp

from ais_bench.benchmark.registry import ICL_INFERENCERS
from ais_bench.benchmark.openicl.icl_retriever import BaseRetriever
from ais_bench.benchmark.openicl.icl_inferencer.icl_base_api_inferencer import BaseApiInferencer
from ais_bench.benchmark.openicl.icl_inferencer.icl_base_local_inferencer import BaseLocalInferencer
from ais_bench.benchmark.openicl.icl_inferencer.output_handler.gen_inferencer_output_handler import GenInferencerOutputHandler


@ICL_INFERENCERS.register_module()
class MyCustomInferencer(BaseApiInferencer, BaseLocalInferencer):
    """Custom inferencer supporting both API models and local models.

    This class inherits from both BaseApiInferencer and BaseLocalInferencer,
    and needs to implement the required methods of both base classes.
    """

    def __init__(
        self,
        model_cfg,
        batch_size: Optional[int] = 1,
        mode: Optional[str] = "infer",
        output_json_filepath: Optional[str] = "./icl_inference_output",
        save_every: Optional[int] = 1,
        **kwargs,
    ) -> None:
        # Initialize through BaseApiInferencer, whose constructor accepts the
        # API-specific arguments (mode, save_every)
        BaseApiInferencer.__init__(
            self,
            model_cfg=model_cfg,
            batch_size=batch_size,
            mode=mode,
            output_json_filepath=output_json_filepath,
            save_every=save_every,
            **kwargs,
        )
        # Initialize output handler
        self.output_handler = GenInferencerOutputHandler(
            perf_mode=self.perf_mode,
            save_every=self.save_every
        )

    async def do_request(self, data: dict, token_bucket: BoundedSemaphore, session: aiohttp.ClientSession):
        """API model inference method (required)"""
        # Implement API model inference logic
        pass

    def batch_inference(self, datum: dict) -> None:
        """Local model inference method (required)"""
        # Implement local model inference logic
        pass

    def get_data_list(self, retriever: BaseRetriever) -> List:
        """Get data list (required)"""
        # Implement data list retrieval logic
        pass
```
For a complete implementation, see GenInferencer.
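Python resolves attributes of a class with two bases through its method resolution order (MRO). The sketch below uses hypothetical stub classes, not the AISBench bases, to show two consequences worth checking in your own class: a method defined in both bases (describe here) resolves to the first base listed unless you override it, which is one reason to define get_data_list explicitly; and calling one base's __init__ directly does not run the other's unless the bases cooperate through super(). Whether BaseApiInferencer and BaseLocalInferencer cooperate that way is not stated on this page, so check their source.

```python
class ApiBaseStub:
    """Hypothetical stand-in for BaseApiInferencer."""

    def __init__(self, **kwargs) -> None:
        self.api_initialized = True

    def describe(self) -> str:
        return "api"


class LocalBaseStub:
    """Hypothetical stand-in for BaseLocalInferencer."""

    def __init__(self, **kwargs) -> None:
        self.local_initialized = True

    def describe(self) -> str:
        return "local"


class DualStub(ApiBaseStub, LocalBaseStub):
    def __init__(self, **kwargs) -> None:
        # Explicit call to one base: only ApiBaseStub.__init__ runs, because
        # ApiBaseStub.__init__ does not call super().__init__()
        ApiBaseStub.__init__(self, **kwargs)


mro_names = [cls.__name__ for cls in DualStub.__mro__]
dual = DualStub()
# mro_names == ["DualStub", "ApiBaseStub", "LocalBaseStub", "object"]
# dual.describe() returns "api"; dual has no local_initialized attribute
```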
Using Custom Inferencers in Configuration Files
After defining a custom inferencer, reference it in the dataset configuration file: in the corresponding file under ais_bench/benchmark/configs/datasets, set the type of inferencer in infer_cfg to the custom inferencer class:
```python
# Importing from the package requires MyCustomInferencer to be exported in icl_inferencer/__init__.py
from ais_bench.benchmark.openicl.icl_inferencer import MyCustomInferencer
from ais_bench.benchmark.openicl.icl_prompt_template import PromptTemplate
from ais_bench.benchmark.openicl.icl_retriever import ZeroRetriever

# MyDataset, mydataset_reader_cfg and mydataset_eval_cfg are placeholders;
# define or import them elsewhere in this configuration file.

# Inference configuration
mydataset_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(
                    role="HUMAN",
                    prompt="{question}\nRemember to put your final answer within \\boxed{}.",
                ),
            ],
        ),
    ),
    retriever=dict(type=ZeroRetriever),  # Retriever configuration
    inferencer=dict(type=MyCustomInferencer),  # Custom inferencer configuration
)

# Dataset configuration list
mydataset_datasets = [
    dict(
        type=MyDataset,  # Custom dataset class name
        # ... Other dataset initialization parameters ...
        reader_cfg=mydataset_reader_cfg,  # Dataset reading configuration
        infer_cfg=mydataset_infer_cfg,  # Inference configuration (including custom inferencer)
        eval_cfg=mydataset_eval_cfg,  # Accuracy evaluation configuration
    )
]
```
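The config can refer to MyCustomInferencer only because the registration decorator has recorded the class in a registry that the framework later uses to instantiate infer_cfg["inferencer"]. The toy Registry below is a hypothetical illustration of that pattern, not AISBench's registry implementation; its build method accepts either a class object (as in the config above) or a registered name as type, and passes the remaining keys to the constructor.

```python
from typing import Callable, Dict, Union


class Registry:
    """Toy registry mapping class names to classes (hypothetical, for illustration)."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._modules: Dict[str, type] = {}

    def register_module(self) -> Callable[[type], type]:
        def decorator(cls: type) -> type:
            self._modules[cls.__name__] = cls  # record the class under its own name
            return cls  # return it unchanged so it is still a normal class
        return decorator

    def build(self, cfg: dict):
        cfg = dict(cfg)  # copy so the caller's config dict is not modified
        obj_type: Union[str, type] = cfg.pop("type")
        cls = self._modules[obj_type] if isinstance(obj_type, str) else obj_type
        return cls(**cfg)  # remaining keys become constructor arguments


DEMO_INFERENCERS = Registry("demo_inferencers")


@DEMO_INFERENCERS.register_module()
class DemoInferencer:
    def __init__(self, batch_size: int = 1) -> None:
        self.batch_size = batch_size


by_name = DEMO_INFERENCERS.build(dict(type="DemoInferencer", batch_size=4))
by_class = DEMO_INFERENCERS.build(dict(type=DemoInferencer))
```

Without the decorator, a lookup by name would raise a KeyError, which is the toy equivalent of the configuration system failing to recognize an unregistered inferencer.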
Notes
Registration Decorator: Custom inferencers must be registered with the @ICL_INFERENCERS.register_module() decorator; otherwise the configuration system cannot recognize them.
Output Handler: Choose an output handler that matches your task. Commonly used ones include:
GenInferencerOutputHandler: output processing for generative tasks
PPLInferencerOutputHandler: output processing for perplexity evaluation
Status Management: For API model inferencers:
Use status_counter to track request status (post, rev, failed, finish, case_finish).
Update the status counter correctly in the do_request method.
Error Handling: Handle exceptions during inference so that the output results include error information for later analysis and debugging.
Performance Mode: If the inferencer needs to support performance evaluation (mode="perf"), make sure that:
For API models, the parse_stream_response interface is implemented (in the model class).
The perf_mode flag is set correctly.
RequestOutput is used to save performance-related metrics.
Data Format: Each dictionary in the data list returned by get_data_list must contain the following fields:
prompt: input prompt
index: data index
data_abbr: dataset identifier
max_out_len: maximum output length
gold: ground truth (optional)
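A cheap way to catch data-format mistakes before a full run is to validate the output of get_data_list in a unit test. check_data_list below is a hypothetical helper, not part of AISBench; it checks only the required keys listed above and treats gold as optional.

```python
from typing import List

# Required keys per the Data Format note; "gold" is optional and not checked
REQUIRED_FIELDS = ("prompt", "index", "data_abbr", "max_out_len")


def check_data_list(data_list: List[dict]) -> List[str]:
    """Return a list of problems; an empty list means every entry has the required fields."""
    problems = []
    for position, item in enumerate(data_list):
        missing = [field for field in REQUIRED_FIELDS if field not in item]
        if missing:
            problems.append(f"entry {position}: missing {', '.join(missing)}")
    return problems


good = [{"prompt": "p", "index": 0, "data_abbr": "d", "max_out_len": 8}]
bad = [{"prompt": "p", "index": 0}]
```

In a test, `assert check_data_list(inferencer.get_data_list(retriever)) == []` would flag a missing field with the offending entry's position, assuming you can construct an inferencer and retriever in your test environment.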