Supporting New Inferencers

The inferencer is the core AISBench component responsible for executing model inference. It follows a different inference path depending on the model type (API model or local model). Before adapting a new inferencer, we recommend first reading how prompt_template and meta_template are defined, to understand how AISBench constructs prompts.

Currently, AISBench supports the following inferencer types:

  • GenInferencer: Inferencer for generative tasks, supporting API models and local models

  • MultiTurnGenInferencer: Inferencer for multi-turn dialogue tasks, supporting API models and local models

  • PPLInferencer: Inferencer for perplexity (PPL) evaluation

Special inference scenarios or custom requirements usually call for a custom inferencer. The interface an inferencer must implement depends on the type of model it calls:

  • API Models: Implement the asynchronous do_request method, which calls a model served over HTTP

  • Local Models: Implement the synchronous batch_inference method, which calls the local model directly to run batch inference
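
To make the difference between the two calling patterns concrete, here is a self-contained sketch. FakeApiModel, FakeLocalModel, run_api_style and run_local_style are hypothetical names, and asyncio.Semaphore stands in for the token_bucket rate limiter; this illustrates the idea and is not how AISBench schedules requests internally. The API path issues one coroutine per sample, while the local path makes one synchronous call per batch.

```python
import asyncio
from typing import Dict, List


class FakeApiModel:
    """Hypothetical stand-in for a model served over HTTP (no real network I/O)."""

    async def generate(self, prompt: str, max_out_len: int) -> str:
        await asyncio.sleep(0)  # yield to the event loop, as a real HTTP call would
        return prompt.upper()[:max_out_len]


class FakeLocalModel:
    """Hypothetical stand-in for an in-process model that consumes a whole batch."""

    def generate(self, prompts: List[str], max_out_len: int) -> List[str]:
        return [p.upper()[:max_out_len] for p in prompts]


async def run_api_style(model: FakeApiModel, data_list: List[Dict], concurrency: int) -> Dict[int, str]:
    """One coroutine per sample (like do_request); a semaphore caps in-flight requests."""
    results: Dict[int, str] = {}
    limit = asyncio.Semaphore(concurrency)

    async def do_request(data: Dict) -> None:
        async with limit:
            results[data["index"]] = await model.generate(data["prompt"], data["max_out_len"])

    await asyncio.gather(*(do_request(d) for d in data_list))
    return results


def run_local_style(model: FakeLocalModel, data_list: List[Dict], batch_size: int, max_out_len: int) -> Dict[int, str]:
    """One synchronous call per batch (like batch_inference), with a single max_out_len."""
    results: Dict[int, str] = {}
    for start in range(0, len(data_list), batch_size):
        batch = data_list[start:start + batch_size]
        outputs = model.generate([d["prompt"] for d in batch], max_out_len)
        for d, out in zip(batch, outputs):
            results[d["index"]] = out
    return results


if __name__ == "__main__":
    data = [{"index": i, "prompt": f"q{i}", "max_out_len": 8} for i in range(5)]
    api_results = asyncio.run(run_api_style(FakeApiModel(), data, concurrency=2))
    local_results = run_local_style(FakeLocalModel(), data, batch_size=2, max_out_len=8)
    print(api_results == local_results)  # True
```

Both paths produce the same per-index results; what differs is how work is scheduled, which is why the two base classes expose different methods.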

Adding API Model Inferencers

To add an inferencer for API models, create a new file my_custom_api_inferencer.py in ais_bench/benchmark/openicl/icl_inferencer, define a class that inherits from BaseApiInferencer, and implement the interfaces below. The currently supported extension points are:

  • (Required) do_request: Executes a single inference request; used for asynchronous inference with API models

  • (Required) get_data_list: Builds the list of inference inputs from the retriever

from multiprocessing import BoundedSemaphore
from typing import List, Optional
import uuid
import copy
import aiohttp

from ais_bench.benchmark.models.output import RequestOutput
from ais_bench.benchmark.registry import ICL_INFERENCERS
from ais_bench.benchmark.openicl.icl_retriever import BaseRetriever
from ais_bench.benchmark.openicl.icl_inferencer.icl_base_api_inferencer import BaseApiInferencer
from ais_bench.benchmark.openicl.icl_inferencer.output_handler.gen_inferencer_output_handler import GenInferencerOutputHandler


@ICL_INFERENCERS.register_module()
class MyCustomApiInferencer(BaseApiInferencer):
    """Custom API model inferencer class.

    Attributes:
        model_cfg: Model configuration
        batch_size (:obj:`int`, optional): Batch size
        output_json_filepath (:obj:`str`, optional): Output JSON file path
        save_every (:obj:`int`, optional): Save intermediate results every N samples
    """

    def __init__(
        self,
        model_cfg,
        batch_size: Optional[int] = 1,
        mode: Optional[str] = "infer",
        output_json_filepath: Optional[str] = "./icl_inference_output",
        save_every: Optional[int] = 1,
        **kwargs,
    ) -> None:
        super().__init__(
            model_cfg=model_cfg,
            batch_size=batch_size,
            mode=mode,
            output_json_filepath=output_json_filepath,
            save_every=save_every,
            **kwargs,
        )

        # Initialize output handler
        self.output_handler = GenInferencerOutputHandler(
            perf_mode=self.perf_mode,
            save_every=self.save_every
        )

    async def do_request(
        self,
        data: dict,
        token_bucket: BoundedSemaphore,
        session: aiohttp.ClientSession
    ):
        """Execute a single inference request.

        Args:
            data: Dictionary containing request data, usually includes the following fields:
                - prompt: Input prompt
                - index: Data index
                - data_abbr: Dataset identifier
                - max_out_len: Maximum output length
                - gold: Ground truth (optional)
            token_bucket: Semaphore for rate limiting
            session: HTTP session object
        """
        data = copy.deepcopy(data)
        index = data.pop("index")
        input = data.pop("prompt")
        data_abbr = data.pop("data_abbr")
        max_out_len = data.pop("max_out_len")
        gold = data.pop("gold", None)

        # Generate unique identifier
        uid = str(uuid.uuid4()).replace("-", "")
        output = RequestOutput(self.perf_mode)
        output.uuid = uid

        # Update status counter
        await self.status_counter.post()

        # Call model for inference
        await self.model.generate(input, max_out_len, output, session=session, **data)

        # Update status
        if output.success:
            await self.status_counter.rev()
        else:
            await self.status_counter.failed()
        await self.status_counter.finish()
        await self.status_counter.case_finish()

        # Report results to output handler
        await self.output_handler.report_cache_info(index, input, output, data_abbr, gold)

    def get_data_list(
        self,
        retriever: BaseRetriever,
    ) -> List:
        """Get data list from retriever.

        Args:
            retriever: Retriever instance, used to get data and generate prompts

        Returns:
            Data list, each element is a dictionary containing information needed for inference
        """
        data_abbr = retriever.dataset.abbr
        ice_idx_list = retriever.retrieve()
        prompt_list = []

        # Generate prompt for each sample
        for idx, ice_idx in enumerate(ice_idx_list):
            ice = retriever.generate_ice(ice_idx)
            prompt = retriever.generate_prompt_for_generate_task(
                idx,
                ice,
                gen_field_replace_token=getattr(self, "gen_field_replace_token", ""),
            )
            # Parse template
            parsed_prompt = self.model.parse_template(prompt, mode="gen")
            prompt_list.append(parsed_prompt)

        self.logger.info("Apply ice template finished")

        # Get ground truth
        gold_ans = retriever.get_gold_ans()

        # Build data list
        data_list = []
        for index, prompt in enumerate(prompt_list):
            data_list.append(
                {
                    "prompt": prompt,
                    "data_abbr": data_abbr,
                    "index": index,
                    "max_out_len": self.model.max_out_len,
                }
            )

        # Add ground truth
        if gold_ans is not None:
            for index, gold in enumerate(gold_ans):
                data_list[index]["gold"] = gold

        # Dataset-specified max_out_len has highest priority
        max_out_lens = retriever.dataset_reader.get_max_out_len()
        if max_out_lens is not None:
            self.logger.warning("Using dataset-specified max_out_len, which takes precedence over the model's max_out_len")
            for index, max_out_len in enumerate(max_out_lens):
                data_list[index]["max_out_len"] = max_out_len if max_out_len else self.model.max_out_len

        return data_list
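
The precedence rule at the end of get_data_list can be checked in isolation. build_data_list below is a hypothetical, framework-free restatement of the assembly steps in the example: it takes already-parsed prompts instead of a retriever, so treat it as a sketch of the logic rather than AISBench code.

```python
from typing import Any, Dict, List, Optional


def build_data_list(
    prompts: List[str],
    data_abbr: str,
    model_max_out_len: int,
    gold_ans: Optional[List[Any]] = None,
    dataset_max_out_lens: Optional[List[Optional[int]]] = None,
) -> List[Dict[str, Any]]:
    """Mirror the data-list assembly at the end of get_data_list (standalone sketch)."""
    data_list = [
        {"prompt": p, "data_abbr": data_abbr, "index": i, "max_out_len": model_max_out_len}
        for i, p in enumerate(prompts)
    ]
    if gold_ans is not None:
        for i, gold in enumerate(gold_ans):
            data_list[i]["gold"] = gold
    if dataset_max_out_lens is not None:
        # A per-sample dataset value wins; a falsy value (None or 0) falls back to the model default.
        for i, n in enumerate(dataset_max_out_lens):
            data_list[i]["max_out_len"] = n if n else model_max_out_len
    return data_list


rows = build_data_list(["a", "b", "c"], "demo", 512, dataset_max_out_lens=[64, None, 0])
print([r["max_out_len"] for r in rows])  # [64, 512, 512]
```

Because the fallback tests truthiness, a dataset value of 0 is treated the same as "not specified" rather than as a zero-length limit.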

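We recommend also exporting the new inferencer class from the package's __init__.py (ais_bench/benchmark/openicl/icl_inferencer/__init__.py) so that it can be imported directly, for example from a dataset configuration file.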

For a complete implementation, see GenInferencer.
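
The do_request example updates status_counter in a fixed order, but if model.generate raises instead of setting output.success to False, the remaining counter calls are skipped. The sketch below shows one way to keep the bookkeeping balanced. RecordingStatusCounter and tracked_request are hypothetical names, and treating any exception as a failed request is an assumption, not documented AISBench behavior.

```python
import asyncio
from typing import Awaitable, Callable, List


class RecordingStatusCounter:
    """Hypothetical stand-in for status_counter that records the calls it receives."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    async def post(self) -> None:
        self.calls.append("post")

    async def rev(self) -> None:
        self.calls.append("rev")

    async def failed(self) -> None:
        self.calls.append("failed")

    async def finish(self) -> None:
        self.calls.append("finish")

    async def case_finish(self) -> None:
        self.calls.append("case_finish")


async def tracked_request(counter: RecordingStatusCounter, send: Callable[[], Awaitable[bool]]) -> bool:
    """Same counter order as do_request, but an exception still ends in failed/finish/case_finish."""
    await counter.post()
    try:
        success = bool(await send())
    except Exception:  # assumption: any exception from the model call counts as a failure
        success = False
    if success:
        await counter.rev()
    else:
        await counter.failed()
    await counter.finish()
    await counter.case_finish()
    return success


async def flaky_request() -> bool:
    raise ConnectionError("connection reset")  # simulated transport error


counter = RecordingStatusCounter()
print(asyncio.run(tracked_request(counter, flaky_request)), counter.calls)
# False ['post', 'failed', 'finish', 'case_finish']
```

In a real inferencer you would also record the error on the output object so it reaches the output handler.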

Adding Local Model Inferencers

To add an inferencer for local models, create a new file my_custom_local_inferencer.py in ais_bench/benchmark/openicl/icl_inferencer, define a class that inherits from BaseLocalInferencer, and implement the interfaces below. The currently supported extension points are:

  • (Required) batch_inference: Executes inference on one batch; used for synchronous inference with local models

  • (Required) get_data_list: Builds the list of inference inputs from the retriever

from typing import List, Optional

from ais_bench.benchmark.registry import ICL_INFERENCERS
from ais_bench.benchmark.openicl.icl_retriever import BaseRetriever
from ais_bench.benchmark.openicl.icl_inferencer.icl_base_local_inferencer import BaseLocalInferencer
from ais_bench.benchmark.openicl.icl_inferencer.output_handler.gen_inferencer_output_handler import GenInferencerOutputHandler


@ICL_INFERENCERS.register_module()
class MyCustomLocalInferencer(BaseLocalInferencer):
    """Custom local model inferencer class.

    Attributes:
        model_cfg: Model configuration
        batch_size (:obj:`int`, optional): Batch size
        output_json_filepath (:obj:`str`, optional): Output JSON file path
        save_every (:obj:`int`, optional): Save intermediate results every N samples
    """

    def __init__(
        self,
        model_cfg,
        batch_size: Optional[int] = 1,
        output_json_filepath: Optional[str] = "./icl_inference_output",
        save_every: Optional[int] = 1,
        **kwargs,
    ) -> None:
        super().__init__(
            model_cfg=model_cfg,
            batch_size=batch_size,
            output_json_filepath=output_json_filepath,
        )

        self.save_every = save_every

        # Initialize output handler
        self.output_handler = GenInferencerOutputHandler(
            perf_mode=False,  # Local inferencers usually don't support performance mode
            save_every=self.save_every
        )

    def batch_inference(
        self,
        datum: dict,
    ) -> None:
        """Execute batch inference.

        Args:
            datum: Dictionary containing batch data, usually includes the following fields:
                - prompt: Input prompt list
                - index: Data index list
                - data_abbr: Dataset identifier list
                - max_out_len: Maximum output length list
                - gold: Ground truth list (optional)
        """
        indices = datum.pop("index")
        inputs = datum.pop("prompt")
        data_abbrs = datum.pop("data_abbr")
        # Pop per-sample max_out_len so it is not forwarded to generate() through
        # **datum; this example does not use the per-sample values
        datum.pop("max_out_len")
        golds = datum.pop("gold", [None] * len(inputs))

        # Call the local model for batch inference. The whole batch uses the
        # unified max_out_len from the model configuration; any fields left in
        # datum are forwarded as keyword arguments
        outputs = self.model.generate(inputs, self.model.max_out_len, **datum)

        # Report each output result
        for index, input, output, data_abbr, gold in zip(
            indices, inputs, outputs, data_abbrs, golds
        ):
            self.output_handler.report_cache_info_sync(
                index, input, output, data_abbr, gold
            )

    def get_data_list(
        self,
        retriever: BaseRetriever,
    ) -> List:
        """Get data list from retriever.

        Args:
            retriever: Retriever instance, used to get data and generate prompts

        Returns:
            Data list, each element is a dictionary containing information needed for inference
        """
        data_abbr = retriever.dataset.abbr
        ice_idx_list = retriever.retrieve()
        prompt_list = []

        # Generate prompt for each sample
        for idx, ice_idx in enumerate(ice_idx_list):
            ice = retriever.generate_ice(ice_idx)
            prompt = retriever.generate_prompt_for_generate_task(
                idx,
                ice,
                gen_field_replace_token=getattr(self, "gen_field_replace_token", ""),
            )
            # Parse template
            parsed_prompt = self.model.parse_template(prompt, mode="gen")
            prompt_list.append(parsed_prompt)

        self.logger.info("Apply ice template finished")

        # Get ground truth
        gold_ans = retriever.get_gold_ans()

        # Build data list
        data_list = []
        for index, prompt in enumerate(prompt_list):
            data_list.append(
                {
                    "prompt": prompt,
                    "data_abbr": data_abbr,
                    "index": index,
                    "max_out_len": self.model.max_out_len,
                }
            )

        # Add ground truth
        if gold_ans is not None:
            for index, gold in enumerate(gold_ans):
                data_list[index]["gold"] = gold

        # Dataset-specified max_out_len has highest priority
        max_out_lens = retriever.dataset_reader.get_max_out_len()
        if max_out_lens is not None:
            self.logger.warning("Using dataset-specified max_out_len, which takes precedence over the model's max_out_len")
            for index, max_out_len in enumerate(max_out_lens):
                data_list[index]["max_out_len"] = max_out_len if max_out_len else self.model.max_out_len

        return data_list

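We recommend also exporting the new inferencer class from the package's __init__.py (ais_bench/benchmark/openicl/icl_inferencer/__init__.py) so that it can be imported directly, for example from a dataset configuration file.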
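
In a decorator-based registry, a class is registered only when the decorator runs, which happens when the module defining the class is imported. If nothing imports my_custom_local_inferencer.py, the class never reaches the registry. The generic pattern is sketched below; SimpleRegistry and INFERENCERS are hypothetical, and the real ICL_INFERENCERS object may behave differently (for example, in how it handles duplicate names).

```python
from typing import Callable, Dict, Type


class SimpleRegistry:
    """Hypothetical name -> class registry illustrating the decorator pattern."""

    def __init__(self, name: str) -> None:
        self.name = name
        self._modules: Dict[str, Type] = {}

    def register_module(self) -> Callable[[Type], Type]:
        def decorator(cls: Type) -> Type:
            # Runs when the class statement executes, i.e. when its module is imported.
            if cls.__name__ in self._modules:
                raise KeyError(f"{cls.__name__} is already registered in {self.name}")
            self._modules[cls.__name__] = cls
            return cls  # return the class unchanged so the decorator is transparent

        return decorator

    def get(self, name: str) -> Type:
        return self._modules[name]


INFERENCERS = SimpleRegistry("inferencer")


@INFERENCERS.register_module()
class MyCustomLocalInferencer:
    """Placeholder body; only the registration matters here."""


print(INFERENCERS.get("MyCustomLocalInferencer") is MyCustomLocalInferencer)  # True
```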

For a complete implementation, see GenInferencer.
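
batch_inference receives datum with one list per field (index, prompt, data_abbr, and so on) rather than one dict per sample. How AISBench groups the entries from get_data_list into batches is not described here; collate_batch is a hypothetical helper that produces the shape the example expects, padding the optional gold field with None and rejecting any other field that only some samples carry.

```python
from typing import Any, Dict, List

OPTIONAL_FIELDS = ("gold",)  # fields that may be absent from individual samples


def collate_batch(samples: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Turn a list of per-sample dicts into one dict of per-field lists (hypothetical)."""
    if not samples:
        return {}
    batch: Dict[str, List[Any]] = {}
    # dict.fromkeys keeps first-seen key order while removing duplicates.
    for key in dict.fromkeys(k for s in samples for k in s):
        if all(key in s for s in samples):
            batch[key] = [s[key] for s in samples]
        elif key in OPTIONAL_FIELDS:
            batch[key] = [s.get(key) for s in samples]  # pad missing optional values with None
        else:
            raise KeyError(f"field {key!r} is missing from some samples")
    return batch


batch = collate_batch([
    {"index": 0, "prompt": "p0", "data_abbr": "d", "max_out_len": 8, "gold": "a"},
    {"index": 1, "prompt": "p1", "data_abbr": "d", "max_out_len": 8},
])
for index, prompt, gold in zip(batch["index"], batch["prompt"], batch["gold"]):
    print(index, prompt, gold)
# 0 p0 a
# 1 p1 None
```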

Inferencers Supporting Both API Models and Local Models

If an inferencer needs to support both API models and local models, it can inherit from both BaseApiInferencer and BaseLocalInferencer and implement the required methods of both base classes. The same inferencer class can then be used with either type of model.
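
With two base classes, Python's method resolution order (MRO) decides which implementation runs and how super().__init__ chains. The sketch uses hypothetical stand-ins ApiBase and LocalBase and assumes both are written cooperatively (each forwards **kwargs to super().__init__). This document does not say whether AISBench's base classes behave this way, which is why the example that follows calls BaseApiInferencer.__init__ explicitly.

```python
from typing import List


class ApiBase:
    """Hypothetical stand-in for BaseApiInferencer (not its real code)."""

    def __init__(self, mode: str = "infer", **kwargs) -> None:
        super().__init__(**kwargs)  # cooperative: pass remaining kwargs to the next class in the MRO
        self.mode = mode

    def get_data_list(self) -> List[dict]:
        raise NotImplementedError


class LocalBase:
    """Hypothetical stand-in for BaseLocalInferencer (not its real code)."""

    def __init__(self, batch_size: int = 1, **kwargs) -> None:
        super().__init__(**kwargs)
        self.batch_size = batch_size

    def get_data_list(self) -> List[dict]:
        raise NotImplementedError


class DualInferencer(ApiBase, LocalBase):
    def __init__(self, **kwargs) -> None:
        # One super() call runs ApiBase.__init__, which in turn runs LocalBase.__init__.
        super().__init__(**kwargs)

    def get_data_list(self) -> List[dict]:
        # Both bases declare get_data_list, so a single override serves both paths.
        return [{"index": 0, "prompt": "hi", "data_abbr": "demo", "max_out_len": 8}]


inferencer = DualInferencer(mode="perf", batch_size=4)
print([cls.__name__ for cls in DualInferencer.__mro__])
# ['DualInferencer', 'ApiBase', 'LocalBase', 'object']
print(inferencer.mode, inferencer.batch_size)  # perf 4
```

Without the override, get_data_list would resolve to ApiBase's version because ApiBase appears first in the MRO.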

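from multiprocessing import BoundedSemaphore
from typing import List, Optional

import aiohttp

from ais_bench.benchmark.registry import ICL_INFERENCERS
from ais_bench.benchmark.openicl.icl_retriever import BaseRetriever
from ais_bench.benchmark.openicl.icl_inferencer.icl_base_api_inferencer import BaseApiInferencer
from ais_bench.benchmark.openicl.icl_inferencer.icl_base_local_inferencer import BaseLocalInferencer
from ais_bench.benchmark.openicl.icl_inferencer.output_handler.gen_inferencer_output_handler import GenInferencerOutputHandler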


@ICL_INFERENCERS.register_module()
class MyCustomInferencer(BaseApiInferencer, BaseLocalInferencer):
    """Custom inferencer supporting both API models and local models.

    This class inherits from both BaseApiInferencer and BaseLocalInferencer,
    and needs to implement the required methods of both base classes.
    """

    def __init__(
        self,
        model_cfg,
        batch_size: Optional[int] = 1,
        mode: Optional[str] = "infer",
        output_json_filepath: Optional[str] = "./icl_inference_output",
        save_every: Optional[int] = 1,
        **kwargs,
    ) -> None:
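        # Only BaseApiInferencer.__init__ is called explicitly here; it accepts
        # mode and save_every. Make sure any state BaseLocalInferencer relies on
        # is initialized as well.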
        BaseApiInferencer.__init__(
            self,
            model_cfg=model_cfg,
            batch_size=batch_size,
            mode=mode,
            output_json_filepath=output_json_filepath,
            save_every=save_every,
            **kwargs,
        )

        # Initialize output handler
        self.output_handler = GenInferencerOutputHandler(
            perf_mode=self.perf_mode,
            save_every=self.save_every
        )

    async def do_request(self, data: dict, token_bucket: BoundedSemaphore, session: aiohttp.ClientSession):
        """API model inference method (required)"""
        # Implement API model inference logic
        pass

    def batch_inference(self, datum: dict) -> None:
        """Local model inference method (required)"""
        # Implement local model inference logic
        pass

    def get_data_list(self, retriever: BaseRetriever) -> List:
        """Get data list (required)"""
        # Implement data list retrieval logic
        pass

For a complete implementation, see GenInferencer.

Using Custom Inferencers in Configuration Files

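To use a custom inferencer, reference it from the dataset configuration file: in the corresponding configuration file under ais_bench/benchmark/configs/datasets, set the type of inferencer in infer_cfg to the custom inferencer class. Importing the class from ais_bench.benchmark.openicl.icl_inferencer, as below, assumes it has been exported from that package's __init__.py: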

from ais_bench.benchmark.openicl.icl_inferencer import MyCustomInferencer
from ais_bench.benchmark.openicl.icl_prompt_template import PromptTemplate
from ais_bench.benchmark.openicl.icl_retriever import ZeroRetriever

# Inference configuration
mydataset_infer_cfg = dict(
    prompt_template=dict(
        type=PromptTemplate,
        template=dict(
            round=[
                dict(
                    role="HUMAN",
                    prompt="{question}\nRemember to put your final answer within \\boxed{}.",
                ),
            ],
        ),
    ),
    retriever=dict(type=ZeroRetriever),      # Retriever configuration
    inferencer=dict(type=MyCustomInferencer), # Custom inferencer configuration
)

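# Dataset configuration list. MyDataset, mydataset_reader_cfg and mydataset_eval_cfg
# are placeholders: import or define them in this file before use.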
mydataset_datasets = [
    dict(
        type=MyDataset,                    # Custom dataset class name
        # ... Other dataset initialization parameters ...
        reader_cfg=mydataset_reader_cfg,   # Dataset reading configuration
        infer_cfg=mydataset_infer_cfg,     # Inference configuration (including custom inferencer)
        eval_cfg=mydataset_eval_cfg        # Accuracy evaluation configuration
    )
]
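
Entries such as inferencer=dict(type=MyCustomInferencer) follow a common "type plus keyword arguments" convention: the framework takes the class from type and instantiates it with the remaining keys. The sketch below assumes that convention; build_from_cfg and DemoInferencer are hypothetical, and which arguments AISBench injects at runtime (such as model_cfg, which the config above does not set) is an assumption, not something this document specifies.

```python
from typing import Any, Dict


def build_from_cfg(cfg: Dict[str, Any], **defaults: Any) -> Any:
    """Hypothetical sketch: turn a dict(type=Cls, ...) entry into an instance."""
    cfg = dict(cfg)  # copy so the caller's config dict is not mutated
    obj_type = cfg.pop("type")
    if not callable(obj_type):
        raise TypeError(f"'type' must be a class or callable, got {obj_type!r}")
    # Values written in the config override defaults supplied by the caller.
    return obj_type(**{**defaults, **cfg})


class DemoInferencer:
    """Hypothetical inferencer used only to exercise build_from_cfg."""

    def __init__(self, model_cfg=None, batch_size: int = 1) -> None:
        self.model_cfg = model_cfg
        self.batch_size = batch_size


inferencer_cfg = dict(type=DemoInferencer, batch_size=4)
inferencer = build_from_cfg(inferencer_cfg, model_cfg={"abbr": "demo-model"})
print(type(inferencer).__name__, inferencer.batch_size)  # DemoInferencer 4
```

Registries often also accept type as a string and look the class up by name, which is one reason registration matters.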

Notes

  1. Registration Decorator: Custom inferencers must be decorated with @ICL_INFERENCERS.register_module() so that the configuration system can recognize them.

  2. Output Handler: Choose the output handler that matches the task. Common choices include:

    • GenInferencerOutputHandler: For output processing of generative tasks

    • PPLInferencerOutputHandler: For output processing of perplexity evaluation

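  3. Status Management: For API model inferencers:

    • Use status_counter to track request status (post, rev, failed, finish, case_finish)

    • Update the counter in do_request in the order used by the example above: post before the request, then rev on success or failed on failure, then finish and case_finish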

  4. Error Handling: Handle exceptions raised during inference so that the output records the error information for later analysis and debugging.

  5. Performance Mode: If the inferencer needs to support performance evaluation (mode="perf"), make sure that:

    • The API model class used by the inferencer implements the parse_stream_response interface

    • The perf_mode flag is set correctly

    • RequestOutput is used to record performance-related metrics

  6. Data Format: Each dictionary in the list returned by get_data_list must contain the following fields (only gold is optional):

    • prompt: Input prompt

    • index: Data index

    • data_abbr: Dataset identifier

    • max_out_len: Maximum output length

    • gold: Ground truth (optional)
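
A cheap way to catch data-format mistakes early is to validate the list just before get_data_list returns it. validate_data_list is a hypothetical helper; the field list comes from item 6 above, while the duplicate-index and positive-max_out_len checks are extra assumptions (outputs are presumably reported per index, so duplicates would collide).

```python
from typing import Any, Dict, List

# Fields every get_data_list entry must carry; gold is optional.
REQUIRED_FIELDS = ("prompt", "index", "data_abbr", "max_out_len")


def validate_data_list(data_list: List[Dict[str, Any]]) -> None:
    """Raise ValueError on the first malformed entry (hypothetical helper)."""
    seen_indices = set()
    for position, item in enumerate(data_list):
        missing = [f for f in REQUIRED_FIELDS if f not in item]
        if missing:
            raise ValueError(f"data_list[{position}] is missing required field(s): {missing}")
        # Assumption: outputs are reported per index, so duplicates would collide.
        if item["index"] in seen_indices:
            raise ValueError(f"data_list[{position}] repeats index {item['index']!r}")
        seen_indices.add(item["index"])
        # Assumption: max_out_len should be a positive integer token count.
        max_out_len = item["max_out_len"]
        if not isinstance(max_out_len, int) or isinstance(max_out_len, bool) or max_out_len <= 0:
            raise ValueError(f"data_list[{position}] has invalid max_out_len {max_out_len!r}")


validate_data_list([{"prompt": "p", "index": 0, "data_abbr": "demo", "max_out_len": 16}])  # passes silently
```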