Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 26 additions & 1 deletion src/google/adk/cli/adk_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,6 +550,28 @@ def _get_root_agent(self, agent_or_app: BaseAgent | App) -> BaseAgent:
return agent_or_app.root_agent
return agent_or_app

def _get_effective_modalities(
self, root_agent: BaseAgent, requested_modalities: List[str]
) -> List[str]:
"""Determines effective modalities, forcing AUDIO for native-audio models.

Native-audio models only support AUDIO modality. This method detects
native-audio models by checking if the model name contains "native-audio"
and forces AUDIO modality for those models.

Args:
root_agent: The root agent of the application.
requested_modalities: The modalities requested by the client.

Returns:
The effective modalities to use.
"""
model = getattr(root_agent, "model", None)
model_name = model if isinstance(model, str) else ""
Comment on lines +569 to +570
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The current logic for extracting the model name only handles the case where the model attribute is a string. The LlmAgent.model attribute can also be a BaseLlm object, in which case isinstance(model, str) would be false, model_name would become an empty string, and the check for "native-audio" would fail.

To make this more robust, you should also handle the case where model is an object (like BaseLlm) that has a model string attribute. It would also be beneficial to add a test case for an LlmAgent initialized with a BaseLlm object to ensure full coverage.

Suggested change
model = getattr(root_agent, "model", None)
model_name = model if isinstance(model, str) else ""
model = getattr(root_agent, "model", None)
model_name = ""
if isinstance(model, str):
model_name = model
elif hasattr(model, "model") and isinstance(getattr(model, "model"), str):
model_name = getattr(model, "model")

if "native-audio" in model_name:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The string "native-audio" is used here to identify specific models and is also present in the test files. To improve maintainability and prevent potential inconsistencies if this identifier changes, consider defining it as a module-level constant.

For example:

_NATIVE_AUDIO_MODEL_TAG = "native-audio"

This would allow you to reference _NATIVE_AUDIO_MODEL_TAG in both the implementation and the tests, making the code more robust and easier to update.

return ["AUDIO"]
return requested_modalities

def _create_runner(self, agentic_app: App) -> Runner:
"""Create a runner with common services."""
return Runner(
Expand Down Expand Up @@ -1652,7 +1674,10 @@ async def run_agent_live(

async def forward_events():
runner = await self.get_runner_async(app_name)
run_config = RunConfig(response_modalities=modalities)
effective_modalities = self._get_effective_modalities(
runner.app.root_agent, modalities
)
run_config = RunConfig(response_modalities=effective_modalities)
async with Aclosing(
runner.run_live(
session=session,
Expand Down
87 changes: 87 additions & 0 deletions tests/unittests/cli/test_fast_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

from fastapi.testclient import TestClient
from google.adk.agents.base_agent import BaseAgent
from google.adk.agents.llm_agent import LlmAgent
from google.adk.agents.run_config import RunConfig
from google.adk.apps.app import App
from google.adk.artifacts.base_artifact_service import ArtifactVersion
Expand Down Expand Up @@ -1411,5 +1412,91 @@ def test_builder_save_rejects_traversal(builder_test_client, tmp_path):
assert not (tmp_path / "app" / "tmp" / "escape.yaml").exists()


def test_native_audio_model_forces_audio_modality():
"""Test that native-audio models force AUDIO modality."""
from google.adk.cli.adk_web_server import AdkWebServer

native_audio_agent = LlmAgent(
name="native_audio_agent",
model="gemini-live-2.5-flash-native-audio",
)

adk_web_server = AdkWebServer(
agent_loader=MagicMock(),
session_service=MagicMock(),
memory_service=MagicMock(),
artifact_service=MagicMock(),
credential_service=MagicMock(),
eval_sets_manager=MagicMock(),
eval_set_results_manager=MagicMock(),
agents_dir=".",
)
Comment on lines +1424 to +1433
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The instantiation of AdkWebServer with mocked services is duplicated across the three new tests (test_native_audio_model_forces_audio_modality, test_non_native_audio_model_keeps_requested_modality, and test_agent_without_model_attribute). To improve code clarity and maintainability, you could extract this common setup into a pytest fixture.

Here's an example of what that fixture could look like:

@pytest.fixture
def adk_web_server_for_modality_tests():
    """Provides an AdkWebServer instance with mocked services for modality tests."""
    from google.adk.cli.adk_web_server import AdkWebServer
    return AdkWebServer(
        agent_loader=MagicMock(),
        session_service=MagicMock(),
        memory_service=MagicMock(),
        artifact_service=MagicMock(),
        credential_service=MagicMock(),
        eval_sets_manager=MagicMock(),
        eval_set_results_manager=MagicMock(),
        agents_dir=".",
    )

Each test could then accept adk_web_server_for_modality_tests as an argument, which would make the tests cleaner and reduce code duplication.


# Test: requesting TEXT should be forced to AUDIO
modalities = adk_web_server._get_effective_modalities(
native_audio_agent, ["TEXT"]
)
assert modalities == ["AUDIO"]

# Test: requesting AUDIO should stay AUDIO
modalities = adk_web_server._get_effective_modalities(
native_audio_agent, ["AUDIO"]
)
assert modalities == ["AUDIO"]


def test_non_native_audio_model_keeps_requested_modality():
"""Test that non-native-audio models keep the requested modality."""
from google.adk.cli.adk_web_server import AdkWebServer

regular_agent = LlmAgent(
name="regular_agent",
model="gemini-2.5-flash",
)

adk_web_server = AdkWebServer(
agent_loader=MagicMock(),
session_service=MagicMock(),
memory_service=MagicMock(),
artifact_service=MagicMock(),
credential_service=MagicMock(),
eval_sets_manager=MagicMock(),
eval_set_results_manager=MagicMock(),
agents_dir=".",
)

# Test: requesting TEXT should stay TEXT
modalities = adk_web_server._get_effective_modalities(regular_agent, ["TEXT"])
assert modalities == ["TEXT"]

# Test: requesting AUDIO should stay AUDIO
modalities = adk_web_server._get_effective_modalities(
regular_agent, ["AUDIO"]
)
assert modalities == ["AUDIO"]


def test_agent_without_model_attribute():
"""Test that agents without model attribute keep requested modality."""
from google.adk.cli.adk_web_server import AdkWebServer

base_agent = DummyAgent(name="base_agent")

adk_web_server = AdkWebServer(
agent_loader=MagicMock(),
session_service=MagicMock(),
memory_service=MagicMock(),
artifact_service=MagicMock(),
credential_service=MagicMock(),
eval_sets_manager=MagicMock(),
eval_set_results_manager=MagicMock(),
agents_dir=".",
)

# Test: BaseAgent without model attr should keep requested modality
modalities = adk_web_server._get_effective_modalities(base_agent, ["TEXT"])
assert modalities == ["TEXT"]


if __name__ == "__main__":
pytest.main(["-xvs", __file__])
Loading