Skip to content

vllm.kernels.helion.register

vLLM Helion kernel registration with pre-tuned config selection.

This module leverages Helion's internal config selection infrastructure to use pre-tuned configs instead of runtime autotuning.

How Helion Normally Works

For each kernel invocation, Helion:

1. Computes a cache key from the input arguments.
2. Looks up the key in its internal compilation cache.
3. On a cache miss, runs autotuning to find the best config.
4. Compiles and caches the kernel with that config.

How We Override It

We override two Helion hooks to use pre-tuned configs:

  1. key: We provide a key function (derived from config_picker) that computes cache keys matching our pre-tuned config keys. This ensures Helion's internal cache uses keys that correspond to configs we've prepared.

  2. autotuner_fn: We provide PresetConfigSearch which, instead of autotuning, simply returns the pre-tuned config for the computed key. On cache miss, Helion calls our autotuner which returns the author-prepared config.

Both hooks use the same config_picker logic to ensure the cache key computed by key matches the config returned by the autotuner.

Key Classes

  • HelionKernelWrapper: Wraps raw kernel + config_picker, creates configured kernels
  • ConfiguredHelionKernel: Platform-specific kernel with pre-tuned configs
  • PresetConfigSearch: Custom autotuner that returns pre-tuned configs

ConfiguredHelionKernel

A configured Helion kernel bound to a specific platform.

Source code in vllm/kernels/helion/register.py
class ConfiguredHelionKernel:
    """A Helion kernel bound to the pre-tuned configs of the current platform.

    Construction loads the platform configs for ``op_name`` and decorates
    ``raw_kernel_func`` with a cache-key function and a preset-config
    autotuner, both derived from ``config_picker``. Calling the instance
    forwards straight to that decorated kernel.
    """

    def __init__(
        self,
        op_name: str,
        config_picker: ConfigPicker | None,
        raw_kernel_func: Callable,
        helion_settings: helion.Settings | None = None,
    ):
        self.op_name = op_name
        self.config_picker = config_picker
        self.raw_kernel_func = raw_kernel_func
        self.helion_settings = helion_settings
        # Loads self.platform / self.configs (ValueError when no configs
        # exist) and then builds the decorated Helion kernel.
        self._decorated_kernel = self._create_decorated_kernel()

    def __call__(self, *args, **kwargs):
        kernel = self._decorated_kernel
        return kernel(*args, **kwargs)

    def _create_key_computer(self):
        """
        Create a key computer function derived from the config picker.

        The returned function receives kernel arguments unpacked (*args) to match
        Helion's key signature (called as self._key_fn(*args)). It returns the
        string form of the key chosen by the picker, falls back to the string
        form of ``CaseKey.default()`` when the picker returns ``None`` and that
        default is among the loaded configs, and otherwise returns ``None``.

        Raises:
            RuntimeError: if no config picker was supplied.
        """
        picker = self.config_picker
        if picker is None:
            raise RuntimeError(
                f"No config picker registered for kernel '{self.op_name}'. "
                f"A config_picker must be provided to register_kernel()."
            )

        candidate_keys = list(self.configs.keys())
        fallback = CaseKey.default()
        has_fallback = fallback in self.configs

        def key_computer(*args):
            # The picker sees the packed argument tuple plus every known key.
            chosen = picker(args, candidate_keys)
            if chosen is None:
                return str(fallback) if has_fallback else None
            return str(chosen)

        return key_computer

    def _create_config_selector(self, key_computer):
        """Build the selector that ``PresetConfigSearch`` calls on a cache miss.

        The selector takes the packed argument tuple, asks ``key_computer``
        for the serialized key, and returns the matching entry of
        ``self.configs``.

        Raises (from the returned selector):
            ValueError: if no key was produced, or the produced key string
                does not correspond to any loaded config.
        """
        # Map serialized keys back to the original config keys.
        lookup = {str(case_key): case_key for case_key in self.configs}

        def config_selector(args):
            key_str = key_computer(*args)

            if key_str is None:
                raise ValueError(
                    f"Config picker returned None for kernel "
                    f"'{self.op_name}' with available config keys: "
                    f"{list(self.configs.keys())}"
                )

            matched = lookup.get(key_str)
            if matched is None:
                raise ValueError(
                    f"Config picker returned invalid config key "
                    f"'{key_str}' for kernel "
                    f"'{self.op_name}'. "
                    f"Available keys: {list(self.configs.keys())}"
                )

            return self.configs[matched]

        return config_selector

    def _load_platform_configs(self) -> None:
        """Populate ``self.platform`` and ``self.configs`` for this op.

        Raises:
            ValueError: if the config manager has no configs for this op on
                the detected platform.
        """
        from vllm.kernels.helion.config_manager import ConfigManager
        from vllm.kernels.helion.utils import get_canonical_gpu_name

        self.platform = get_canonical_gpu_name()
        self.configs = ConfigManager().get_platform_configs(
            self.op_name, self.platform
        )

        if not self.configs:
            raise ValueError(
                f"No configs available for kernel '{self.op_name}' "
                f"on platform '{self.platform}'"
            )

    def _create_decorated_kernel(self) -> Callable[..., Any]:
        """Load configs and wrap the raw kernel with our key/autotuner hooks."""
        self._load_platform_configs()

        key_computer = self._create_key_computer()
        config_selector = self._create_config_selector(key_computer)

        # Both hooks share key_computer, so the cache key Helion computes
        # always agrees with the config the preset "autotuner" returns.
        extra_kwargs = {
            "autotuner_fn": lambda _, args: PresetConfigSearch(args, config_selector),
            "key": key_computer,
        }

        logger.debug(
            "Creating decorated kernel %s with custom autotuner on platform %s",
            self.op_name,
            self.platform,
        )
        return create_helion_decorated_kernel(
            self.raw_kernel_func, self.helion_settings, extra_kwargs
        )

_create_key_computer

_create_key_computer()

Create a key computer function derived from the config picker.

The returned function receives kernel arguments unpacked (`*args`) to match Helion's key signature (called as `self._key_fn(*args)`).

Source code in vllm/kernels/helion/register.py
def _create_key_computer(self):
    """
    Create a key computer function derived from the config picker.

    The returned function receives kernel arguments unpacked (*args) to match
    Helion's key signature (called as self._key_fn(*args)). It returns the
    string form of the picker's choice, or the string form of
    ``CaseKey.default()`` when the picker returns ``None`` and that default is
    among ``self.configs``, or ``None`` when neither applies.

    Raises:
        RuntimeError: if ``self.config_picker`` is ``None``.
    """
    picker = self.config_picker
    if picker is None:
        raise RuntimeError(
            f"No config picker registered for kernel '{self.op_name}'. "
            f"A config_picker must be provided to register_kernel()."
        )

    candidate_keys = list(self.configs.keys())
    fallback = CaseKey.default()
    has_fallback = fallback in self.configs

    def key_computer(*args):
        # The picker sees the packed argument tuple plus every known key.
        chosen = picker(args, candidate_keys)
        if chosen is None:
            return str(fallback) if has_fallback else None
        return str(chosen)

    return key_computer

HelionKernelWrapper

Wrapper for Helion kernels with pre-tuned config selection and HOP support.

Source code in vllm/kernels/helion/register.py
class HelionKernelWrapper:
    """User-facing handle for a registered Helion kernel.

    Holds the raw kernel plus its config picker. Without HOP support it
    registers a ``torch.ops.vllm_helion`` custom op and dispatches through
    it; with HOP support it eagerly builds a ``ConfiguredHelionKernel`` and
    calls that directly. If setup raises ``ValueError`` (for example, when no
    pre-tuned configs exist), the wrapper is marked disabled and later calls
    raise ``RuntimeError``.
    """

    def __init__(
        self,
        raw_kernel_func: Callable,
        op_name: str,
        fake_impl: Callable,
        config_picker: ConfigPicker,
        helion_settings: helion.Settings | None = None,
        input_generator: (Callable[[], dict[CaseKey, tuple[Any, ...]]] | None) = None,
    ):
        # Reject settings that would clash with the preset-config autotuner
        # before any state is set up.
        validate_helion_settings(helion_settings, op_name)

        self.op_name = op_name
        self.raw_kernel_func = raw_kernel_func
        self.helion_settings = helion_settings
        self._fake_impl = fake_impl
        self._config_picker = config_picker
        self._input_generator = input_generator
        self._configured_kernel: ConfiguredHelionKernel | None = None
        # TODO(@gmagogsfm): Remove this disable flag once integrated with vLLM IR,
        # which handles op enablement/disablement.
        self._disabled = False
        self._disabled_reason: str | None = None

        try:
            if _HOP_AVAILABLE:
                self.get_configured_op()
            else:
                self._get_or_register_custom_op()
        except ValueError as exc:
            # Setup failures (e.g. missing configs) disable the kernel rather
            # than failing registration outright.
            self._disabled = True
            self._disabled_reason = str(exc)
            logger.warning(
                "Helion kernel '%s' is disabled: %s",
                op_name,
                self._disabled_reason,
            )

    def _raise_if_disabled(self) -> None:
        """Raise ``RuntimeError`` when the kernel was disabled during setup."""
        if self._disabled:
            raise RuntimeError(
                f"Helion kernel '{self.op_name}' is disabled: {self._disabled_reason}"
            )

    def __call__(self, *args, **kwargs):
        self._raise_if_disabled()
        if _HOP_AVAILABLE:
            assert self._configured_kernel is not None, (
                f"Kernel '{self.op_name}' was not initialized. "
                "Please open an issue on GitHub."
            )
            # During Dynamo tracing, this call will be intercepted by our custom
            # HelionKernelWrapperVariable and handled via proper HOP emission.
            # During eager execution, call the kernel directly.
            return self._configured_kernel(*args, **kwargs)
        # No HOP support: route through the registered custom op.
        return getattr(torch.ops.vllm_helion, self.op_name)(*args, **kwargs)

    def get_inputs(self) -> dict[CaseKey, tuple[Any, ...]]:
        """Return the example inputs produced by the registered generator.

        Raises:
            NotImplementedError: if no input generator was registered.
        """
        generator = self._input_generator
        if generator is None:
            raise NotImplementedError(
                f"No input generator registered for kernel '{self.op_name}'. "
                f"Use register_kernel(..., input_generator=...) to register one."
            )
        return generator()

    def run_autotune(
        self,
        inputs: tuple[Any, ...],
        autotune_effort: str = "quick",
    ) -> Config:
        """Run autotuning for a single input configuration.

        Builds a fresh decorated kernel with autotuning enabled (errors in
        individual candidate configs ignored) and tunes it on ``inputs``.
        This does not consult the disabled flag.
        """
        autotune_kernel = create_helion_decorated_kernel(
            self.raw_kernel_func,
            self.helion_settings,
            {"autotune_effort": autotune_effort, "autotune_ignore_errors": True},
        )
        return autotune_kernel.autotune(inputs)

    def get_configured_op(self) -> ConfiguredHelionKernel:
        """Return the cached ``ConfiguredHelionKernel``, creating it on first use.

        Raises:
            RuntimeError: if the kernel is disabled.
        """
        self._raise_if_disabled()
        if self._configured_kernel is not None:
            return self._configured_kernel
        self._configured_kernel = ConfiguredHelionKernel(
            op_name=self.op_name,
            config_picker=self._config_picker,
            raw_kernel_func=self.raw_kernel_func,
            helion_settings=self.helion_settings,
        )
        return self._configured_kernel

    def _get_or_register_custom_op(self) -> Any:
        """Return ``torch.ops.vllm_helion.<op_name>``, registering it if absent."""
        if not hasattr(torch.ops.vllm_helion, self.op_name):
            configured_kernel = self.get_configured_op()

            logger.info("Registering op: vllm_helion::%s", self.op_name)
            direct_register_custom_op(
                op_name=self.op_name,
                op_func=configured_kernel._decorated_kernel,
                mutates_args=None,
                fake_impl=self._fake_impl,
                target_lib=vllm_helion_lib,
            )
        return getattr(torch.ops.vllm_helion, self.op_name)

run_autotune

run_autotune(
    inputs: tuple[Any, ...], autotune_effort: str = "quick"
) -> Config

Run autotuning for a single input configuration.

Source code in vllm/kernels/helion/register.py
def run_autotune(
    self,
    inputs: tuple[Any, ...],
    autotune_effort: str = "quick",
) -> Config:
    """Run autotuning for a single input configuration.

    A fresh decorated kernel is built with the requested autotune effort and
    with errors in individual candidate configs ignored, then tuned on
    ``inputs``; the resulting config is returned.
    """
    autotune_kernel = create_helion_decorated_kernel(
        self.raw_kernel_func,
        self.helion_settings,
        dict(autotune_effort=autotune_effort, autotune_ignore_errors=True),
    )
    return autotune_kernel.autotune(inputs)

PresetConfigSearch

Bases: BaseAutotuner

Custom autotuner that uses a preset config selector instead of autotuning.

Source code in vllm/kernels/helion/register.py
class PresetConfigSearch(BaseAutotuner):
    """Autotuner stand-in that returns a preset config instead of searching.

    Helion calls this autotuner on a cache miss; ``autotune`` simply hands
    the arguments captured at construction time to ``config_selector`` and
    returns whatever config it picks.
    """

    def __init__(
        self,
        args: tuple[Any, ...],
        config_selector: Callable[[tuple[Any, ...]], Config],
    ):
        self.config_selector = config_selector
        self.args = args

    def autotune(self, *, skip_cache: bool = False) -> Config:
        # skip_cache is accepted (presumably to match BaseAutotuner's
        # signature) but unused: no search is performed here.
        selector = self.config_selector
        return selector(self.args)

_register_vllm_helion_dynamo_variable

_register_vllm_helion_dynamo_variable()

Register HelionKernelWrapper with Dynamo's VariableBuilder.

When Dynamo encounters a HelionKernelWrapper during tracing, this extracts the underlying Helion Kernel and delegates to Helion's own registered Kernel handler, which handles HOP emission, side table registration, and inductor lowering setup.

Source code in vllm/kernels/helion/register.py
def _register_vllm_helion_dynamo_variable():
    """Teach Dynamo's VariableBuilder how to wrap a HelionKernelWrapper.

    The installed handler pulls the configured Helion ``Kernel`` out of the
    wrapper. When torch.compile fusion is supported it defers to the handler
    already registered for Helion's ``Kernel`` type (which handles HOP
    emission, side table registration, and inductor lowering setup).
    Otherwise it adds the kernel to ``helion_kernel_side_table``, installs an
    ``ID_MATCH`` guard, and returns a ``HelionKernelVariable``.
    """

    def wrap_helion_kernel_wrapper(
        builder: VariableBuilder, value: HelionKernelWrapper
    ):
        kernel = value.get_configured_op()._decorated_kernel
        if not supports_torch_compile_fusion():
            # Fallback: register the kernel in the side table ourselves.
            kernel_idx = helion_kernel_side_table.add_kernel(kernel)
            builder.install_guards(GuardBuilder.ID_MATCH)
            return HelionKernelVariable(kernel, kernel_idx, source=builder.source)
        # Delegate to the handler Helion registered for raw Kernel objects.
        kernel_handler = VariableBuilder._type_dispatch()[Kernel]
        return kernel_handler(builder, kernel)

    VariableBuilder._type_dispatch()[HelionKernelWrapper] = wrap_helion_kernel_wrapper

register_kernel

register_kernel(
    op_name: str | None = None,
    *,
    config_picker: ConfigPicker,
    fake_impl: Callable | None = None,
    helion_settings: Settings | None = None,
    input_generator: Callable[
        [], dict[CaseKey, tuple[Any, ...]]
    ]
    | None = None,
) -> Callable[[Callable], HelionKernelWrapper]

Register a Helion kernel with pre-tuned config selection.

Parameters:

Name Type Description Default
config_picker ConfigPicker

Required. Receives (args, config_keys) where each config key is a dict[str, Any] mapping parameter names to values. Return the best-matching dict, or None to fall back to the default config.

Example::

def pick_config(args, config_keys):
    x = args[0]
    # Choose the key whose "size" is closest to the first input's leading dim.
    return min(config_keys, key=lambda key: abs(key["size"] - x.shape[0]))
required
input_generator Callable[[], dict[CaseKey, tuple[Any, ...]]] | None

Optional. Returns dict[str, tuple] where each key is a serialized config key and each value is a tuple of arguments to pass to the kernel.

Example::

def generate_inputs():
    # One entry per serialized config key: (input tensor, scalar argument).
    return {
        str(n): (torch.randn(n, device="cuda"), 0.5)
        for n in (4096, 8192)
    }
None
Source code in vllm/kernels/helion/register.py
def register_kernel(
    op_name: str | None = None,
    *,
    config_picker: ConfigPicker,
    fake_impl: Callable | None = None,
    helion_settings: helion.Settings | None = None,
    input_generator: (Callable[[], dict[CaseKey, tuple[Any, ...]]] | None) = None,
) -> Callable[[Callable], HelionKernelWrapper]:
    """Register a Helion kernel with pre-tuned config selection.

    Returns a decorator that wraps the decorated function in a
    ``HelionKernelWrapper`` and records it in ``_REGISTERED_KERNELS`` under
    its op name.

    Args:
        op_name: Name to register the kernel under. When omitted (or empty),
            the decorated function's ``__name__`` is used.
        config_picker: Required. Receives ``(args, config_keys)``
            where each config key is a ``dict[str, Any]`` mapping
            parameter names to values.  Return the best-matching
            dict, or ``None`` to fall back to the default config
            (when one is available for the platform).

            Example::

                def pick_config(args, config_keys):
                    x = args[0]
                    best = min(config_keys, key=lambda k: abs(k["size"] - x.shape[0]))
                    return best

        fake_impl: Optional fake implementation handed to the wrapper. When
            ``None``, one is generated with ``infer_fake_impl``.
        helion_settings: Optional ``helion.Settings`` forwarded to the
            wrapper, and to ``infer_fake_impl`` when ``fake_impl`` is
            generated.
        input_generator: Optional. Returns ``dict[str, tuple]`` where
            each key is a serialized config key and each value is a
            tuple of arguments to pass to the kernel.

            Example::

                def generate_inputs():
                    return {
                        "4096": (torch.randn(4096, device="cuda"), 0.5),
                        "8192": (torch.randn(8192, device="cuda"), 0.5),
                    }

    Returns:
        A decorator that takes the raw kernel function and returns its
        ``HelionKernelWrapper``.

    Raises:
        ValueError: raised by the decorator when a kernel with the same op
            name is already registered.
    """

    def decorator(kernel_func: Callable) -> HelionKernelWrapper:
        final_op_name = op_name or kernel_func.__name__

        if final_op_name in _REGISTERED_KERNELS:
            raise ValueError(
                f"Helion kernel '{final_op_name}' is already registered. "
                f"Use a different op_name or check for duplicate registrations."
            )

        final_fake_impl = fake_impl
        if final_fake_impl is None:
            final_fake_impl = infer_fake_impl(kernel_func, helion_settings)
            # Log under the registered op name (which may differ from the
            # Python function name) so it matches the registry key.
            logger.debug(
                "Auto-generated fake_impl for Helion kernel '%s'",
                final_op_name,
            )

        kernel_wrapper = HelionKernelWrapper(
            raw_kernel_func=kernel_func,
            op_name=final_op_name,
            fake_impl=final_fake_impl,
            config_picker=config_picker,
            helion_settings=helion_settings,
            input_generator=input_generator,
        )

        # Record the wrapper even if it came up disabled, so the name stays
        # reserved and duplicate registrations are still detected.
        _REGISTERED_KERNELS[final_op_name] = kernel_wrapper

        logger.info(
            "Registered Helion kernel '%s' as HelionKernelWrapper",
            final_op_name,
        )

        return kernel_wrapper

    return decorator