Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion slime/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

from slime.backends.sglang_utils.sglang_config import ModelConfig, ServerGroupConfig, SglangConfig
from slime.backends.sglang_utils.sglang_engine import SGLangEngine
from slime.rollout.base_types import call_rollout_fn
from slime.rollout.base_types import apply_rollout_sample_filter, call_rollout_fn
from slime.utils import logging_utils
from slime.utils.dp_schedule import build_dp_schedule
from slime.utils.health_monitor import RolloutHealthMonitor
Expand Down Expand Up @@ -601,6 +601,7 @@ def _get_rollout_data(self, rollout_id):
# set the same rollout_id on every sibling so the loss reducer counts
# the rollout once instead of N times.
_validate_rollout_id_annotated(data)
apply_rollout_sample_filter(self.args, data)
# flatten the data if it is a list of lists
while isinstance(data[0], list):
data = list(itertools.chain.from_iterable(data))
Expand Down
11 changes: 11 additions & 0 deletions slime/rollout/base_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,14 @@ def call_rollout_fn(fn, *args, evaluation: bool, **kwargs):
output = RolloutFnEvalOutput(data=output) if evaluation else RolloutFnTrainOutput(samples=output)

return output


def apply_rollout_sample_filter(args, samples: list[Any]) -> None:
    """Apply the user-configured rollout sample filter to grouped samples in place.

    The filter is looked up from ``args.rollout_sample_filter_path`` and called
    as ``filter_func(args, samples)``; whatever it returns is ignored, so the
    filter is expected to mutate ``samples`` (or the objects inside it).

    Args:
        args: Argument namespace. Only ``rollout_sample_filter_path`` is read.
        samples: Grouped rollout samples handed to the filter unchanged.

    Returns:
        None. This is a no-op when no filter path is configured.
    """
    # Treat a missing attribute or an empty path the same as "no filter
    # configured" so minimal args objects do not fail with AttributeError and
    # an empty string is never handed to the loader.
    filter_path = getattr(args, "rollout_sample_filter_path", None)
    if not filter_path:
        return

    # Imported lazily, as in the original; presumably to keep this module free
    # of an import-time dependency on slime.utils.misc -- confirm before hoisting.
    from slime.utils.misc import load_function

    filter_func = load_function(filter_path)
    filter_func(args, samples)
6 changes: 2 additions & 4 deletions slime/rollout/sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,11 +475,9 @@ async def generate_rollout_async(

# reset the global state to prevent effects on the next rollout or eval.
state.reset()
if args.rollout_sample_filter_path is not None:
filter_func = load_function(args.rollout_sample_filter_path)
filter_func(args, data)

# There can be circumstances where users want to process all samples including filtered ones.
# RolloutManager applies rollout_sample_filter_path after rollout functions return.
# This hook stays here because it needs all samples, including filtered ones.
if args.rollout_all_samples_process_path is not None:
process_func = load_function(args.rollout_all_samples_process_path)
process_func(args, all_samples, data_source)
Expand Down
30 changes: 29 additions & 1 deletion tests/plugin_contracts/test_plugin_rollout_contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@
DEFAULT_ROLLOUT_FUNCTION_PATH = "slime.rollout.sglang_rollout.generate_rollout"
REFERENCE_ROLLOUT_FUNCTION_PATH = "plugin_contracts.test_plugin_rollout_contracts.valid_rollout_function"

from slime.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput, call_rollout_fn
from slime.rollout.base_types import (
RolloutFnEvalOutput,
RolloutFnTrainOutput,
apply_rollout_sample_filter,
call_rollout_fn,
)
from slime.rollout.sglang_rollout import generate_rollout as default_generate_rollout
from slime.utils.misc import load_function
from slime.utils.types import Sample
Expand Down Expand Up @@ -81,6 +86,11 @@ def valid_rollout_function(args, rollout_id, data_source, evaluation=False):
return RolloutFnTrainOutput(samples=groups, metrics={"source": "contract"})


def drop_last_sample_filter(args, groups: list[list[Sample]]) -> None:
    """Flag the last sample of every group for removal, in place.

    Used as a reference rollout sample filter in the contract tests: it sets
    ``remove_sample = True`` on the final element of each group and leaves all
    other samples untouched.

    Args:
        args: Argument namespace; unused, present to match the filter signature.
        groups: Grouped samples; each inner list is one group.
    """
    for group in groups:
        # An empty group has no last sample; skip it instead of raising IndexError.
        if not group:
            continue
        group[-1].remove_sample = True


def invalid_rollout_function(args, rollout_id, data_source, evaluation=False):
sample = make_sample(0)
sample.reward = None
Expand Down Expand Up @@ -180,6 +190,24 @@ def test_local_rollout_plugin_aligns_with_default_input_output_format():
assert_rollout_function_matches_default_contract(valid_rollout_function)


def test_rollout_sample_filter_applies_to_custom_rollout_output():
    """Check that the sample filter acts on the output of a custom rollout function.

    The rollout is produced by ``valid_rollout_function`` through
    ``call_rollout_fn`` and then passed to ``apply_rollout_sample_filter``
    configured with ``drop_last_sample_filter``. The expected flags pin two
    groups of two samples, with only the last sample of each group marked.
    """
    # Minimal stand-in for the parsed arguments: only the filter path is set.
    namespace = {
        "rollout_sample_filter_path": "plugin_contracts.test_plugin_rollout_contracts.drop_last_sample_filter",
    }
    args = type("Args", (), namespace)()

    train_output = call_rollout_fn(valid_rollout_function, args, 2, ContractDataSource(), evaluation=False)
    apply_rollout_sample_filter(args, train_output.samples)

    # Collect the per-group removal flags after filtering.
    observed = []
    for group in train_output.samples:
        observed.append([sample.remove_sample for sample in group])

    expected = [[False, True], [False, True]]
    assert observed == expected


def test_misaligned_rollout_plugin_is_rejected():
with pytest.raises(AssertionError):
assert_rollout_function_matches_default_contract(invalid_rollout_function)
Expand Down
Loading