65 changes: 57 additions & 8 deletions src/google/adk/models/gemini_llm_connection.py
@@ -139,14 +139,21 @@ async def send_realtime(self, input: RealtimeInput):
else:
raise ValueError('Unsupported input type: %s' % type(input))

def __build_full_text_response(self, text: str):
def __build_full_text_response(
self,
text: str,
grounding_metadata: types.GroundingMetadata | None = None,
interrupted: bool = False,
):
"""Builds a full text response.

The text should not be partial, and the returned LlmResponse is not
partial.

Args:
text: The text to be included in the response.
grounding_metadata: Optional grounding metadata to include.
interrupted: Whether this response was interrupted.

Returns:
An LlmResponse containing the full text.
@@ -156,6 +163,8 @@ def __build_full_text_response(self, text: str):
role='model',
parts=[types.Part.from_text(text=text)],
),
grounding_metadata=grounding_metadata,
interrupted=interrupted,
)

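For context, a hypothetical consumer of `receive()` (not part of this diff) showing how the full-text responses built by this helper differ from partial ones; `connection` is assumed to be an established `GeminiLlmConnection`:

```python
# Hypothetical usage sketch: collect only the merged full-text events.
# Partial events carry partial=True; full-text events (built by
# __build_full_text_response) do not, and carry any grounding metadata
# captured during the turn.
async def collect_final_text(connection) -> str:
    final = ''
    async for response in connection.receive():
        if response.turn_complete:
            break
        if response.content and response.content.parts:
            part = response.content.parts[0]
            if part.text and not response.partial:
                final += part.text
    return final
```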
async def receive(self) -> AsyncGenerator[LlmResponse, None]:
@@ -166,6 +175,7 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
"""

text = ''
last_grounding_metadata = None
async with Aclosing(self._gemini_session.receive()) as agen:
# TODO(b/440101573): Reuse StreamingResponseAggregator to accumulate
# partial content and emit responses as needed.
@@ -179,17 +189,38 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
)
if message.server_content:
content = message.server_content.model_turn
# Extract grounding_metadata from server_content (for VertexAiSearchTool, etc.)
grounding_metadata = message.server_content.grounding_metadata
if grounding_metadata:
last_grounding_metadata = grounding_metadata
# Warn if grounding_metadata is incomplete (has queries but no chunks)
# This helps identify backend issues with Vertex AI Search
if (
grounding_metadata.retrieval_queries
and not grounding_metadata.grounding_chunks
):
logger.warning(
'Incomplete grounding_metadata received: retrieval_queries=%s'
' but grounding_chunks is empty. This may indicate a'
' transient issue with the Vertex AI Search backend.',
grounding_metadata.retrieval_queries,
)
if content and content.parts:
llm_response = LlmResponse(
content=content, interrupted=message.server_content.interrupted
content=content,
interrupted=message.server_content.interrupted,
grounding_metadata=grounding_metadata,
)
if content.parts[0].text:
text += content.parts[0].text
llm_response.partial = True
# don't yield the merged text event when receiving audio data
elif text and not content.parts[0].inline_data:
yield self.__build_full_text_response(text)
yield self.__build_full_text_response(
text, last_grounding_metadata
)
text = ''
last_grounding_metadata = None
Review comment (Contributor, severity: high):

This line prematurely resets last_grounding_metadata. The grounding metadata should persist for the entire duration of a turn and only be reset when the turn is complete, as is correctly handled in the turn_complete block.

Other parts of the code that flush buffered text (e.g., when handling tool_call or interrupted signals) correctly preserve the metadata for the remainder of the turn. Resetting it here is inconsistent with that logic and could cause grounding information to be lost before the turn ends.

To ensure consistent behavior, this line should be removed.
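A minimal sketch of the change the comment appears to request, in suggestion form (the surrounding lines are unchanged; only the reset is dropped, so the metadata survives until `turn_complete`):

```diff
               elif text and not content.parts[0].inline_data:
                 yield self.__build_full_text_response(
                     text, last_grounding_metadata
                 )
                 text = ''
-                last_grounding_metadata = None
```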

yield llm_response
# Note: in some cases, tool_call may arrive before
# generation_complete, causing transcription to appear after
@@ -266,32 +297,50 @@ async def receive(self) -> AsyncGenerator[LlmResponse, None]:
self._output_transcription_text = ''
if message.server_content.turn_complete:
if text:
yield self.__build_full_text_response(text)
yield self.__build_full_text_response(
text,
last_grounding_metadata,
interrupted=message.server_content.interrupted,
)
text = ''
yield LlmResponse(
turn_complete=True,
interrupted=message.server_content.interrupted,
grounding_metadata=last_grounding_metadata,
)
last_grounding_metadata = None # Reset after yielding
break
# In case of empty content or parts, we still surface it.
# If it's an interrupted message, we merge the previous partial
# text; otherwise we don't merge, because content can be None when
# the model safety threshold is triggered.
if message.server_content.interrupted:
if text:
yield self.__build_full_text_response(text)
yield self.__build_full_text_response(
text, last_grounding_metadata, interrupted=True
)
text = ''
else:
yield LlmResponse(interrupted=message.server_content.interrupted)
yield LlmResponse(
interrupted=message.server_content.interrupted,
grounding_metadata=last_grounding_metadata,
)
if message.tool_call:
if text:
yield self.__build_full_text_response(text)
yield self.__build_full_text_response(text, last_grounding_metadata)
text = ''
parts = [
types.Part(function_call=function_call)
for function_call in message.tool_call.function_calls
]
yield LlmResponse(content=types.Content(role='model', parts=parts))
yield LlmResponse(
content=types.Content(role='model', parts=parts),
grounding_metadata=last_grounding_metadata,
)
# Note: last_grounding_metadata is NOT reset here because tool_call
# is part of an ongoing turn. The metadata persists until turn_complete
# or interrupted with break, ensuring subsequent messages in the same
# turn can access the grounding information.
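To make that lifecycle concrete, a minimal standalone sketch (hypothetical helper functions, not this class's API) of when the metadata is captured, attached, and cleared within a single turn:

```python
# Hypothetical sketch of the grounding-metadata lifecycle in one turn:
# captured when server_content carries it, attached to every response
# flushed mid-turn (including tool_call), and cleared only at
# turn_complete.
last_grounding_metadata = None

def on_server_content(grounding_metadata):
    global last_grounding_metadata
    if grounding_metadata:
        last_grounding_metadata = grounding_metadata

def on_tool_call():
    # Attach but keep: a tool call is part of the ongoing turn.
    return last_grounding_metadata

def on_turn_complete():
    global last_grounding_metadata
    metadata, last_grounding_metadata = last_grounding_metadata, None
    return metadata
```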
if message.session_resumption_update:
logger.debug('Received session resumption message: %s', message)
yield (