From: Simon Glass <simon.glass@canonical.com> The agent-message-streaming pattern (async iteration, text extraction and conversation-log collection) is duplicated in run() and run_review_agent(). Extract it into a shared run_agent_collect() helper. Co-developed-by: Claude Opus 4.6 <noreply@anthropic.com> Signed-off-by: Simon Glass <simon.glass@canonical.com> --- tools/pickman/agent.py | 55 ++++++++++++++++++++----------------- tools/pickman/ftest.py | 62 ++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 92 insertions(+), 25 deletions(-) diff --git a/tools/pickman/agent.py b/tools/pickman/agent.py index 85f8efee1df..63952c1c005 100644 --- a/tools/pickman/agent.py +++ b/tools/pickman/agent.py @@ -55,6 +55,34 @@ def check_available(): return True +async def run_agent_collect(prompt, options): + """Run a Claude agent and collect its conversation log + + Sends the prompt to a Claude agent, streams output to stdout and + collects all text blocks into a conversation log. + + Args: + prompt (str): The prompt to send to the agent + options (ClaudeAgentOptions): Agent configuration + + Returns: + tuple: (success, conversation_log) where success is bool and + conversation_log is the agent's output text + """ + conversation_log = [] + try: + async for message in query(prompt=prompt, options=options): + if hasattr(message, 'content'): + for block in message.content: + if hasattr(block, 'text'): + print(block.text) + conversation_log.append(block.text) + return True, '\n\n'.join(conversation_log) + except (RuntimeError, ValueError, OSError) as exc: + tout.error(f'Agent failed: {exc}') + return False, '\n\n'.join(conversation_log) + + def is_qconfig_commit(subject): """Check if a commit subject indicates a qconfig resync commit @@ -228,19 +256,7 @@ this means the series was already applied via a different path.
In this case: tout.info(f'Starting Claude agent to cherry-pick {len(commits)} commits...') tout.info('') - conversation_log = [] - try: - async for message in query(prompt=prompt, options=options): - # Print agent output and capture it - if hasattr(message, 'content'): - for block in message.content: - if hasattr(block, 'text'): - print(block.text) - conversation_log.append(block.text) - return True, '\n\n'.join(conversation_log) - except (RuntimeError, ValueError, OSError) as exc: - tout.error(f'Agent failed: {exc}') - return False, '\n\n'.join(conversation_log) + return await run_agent_collect(prompt, options) def read_signal_file(repo_path=None): @@ -492,18 +508,7 @@ async def run_review_agent(mr_iid, branch_name, comments, remote, tout.info(f'Starting Claude agent to {task_desc}...') tout.info('') - conversation_log = [] - try: - async for message in query(prompt=prompt, options=options): - if hasattr(message, 'content'): - for block in message.content: - if hasattr(block, 'text'): - print(block.text) - conversation_log.append(block.text) - return True, '\n\n'.join(conversation_log) - except (RuntimeError, ValueError, OSError) as exc: - tout.error(f'Agent failed: {exc}') - return False, '\n\n'.join(conversation_log) + return await run_agent_collect(prompt, options) # pylint: disable=too-many-arguments diff --git a/tools/pickman/ftest.py b/tools/pickman/ftest.py index de6bce40614..42ce05962e2 100644 --- a/tools/pickman/ftest.py +++ b/tools/pickman/ftest.py @@ -6,6 +6,7 @@ # pylint: disable=too-many-lines """Tests for pickman.""" +import asyncio import argparse import os import shutil @@ -2971,6 +2972,67 @@ class TestExecuteApply(unittest.TestCase): dbs.close() +class TestRunAgentCollect(unittest.TestCase): + """Tests for run_agent_collect function.""" + + def test_success(self): + """Test successful agent run collects text blocks.""" + block1 = mock.MagicMock() + block1.text = 'hello' + block2 = mock.MagicMock() + block2.text = 'world' + msg = mock.MagicMock() + 
msg.content = [block1, block2] + + async def fake_query(**kwargs): + yield msg + + with mock.patch.object(agent, 'query', fake_query, create=True): + with terminal.capture(): + opts = mock.MagicMock() + success, log = asyncio.run( + agent.run_agent_collect('prompt', opts)) + + self.assertTrue(success) + self.assertEqual(log, 'hello\n\nworld') + + def test_failure(self): + """Test agent failure returns False with partial log.""" + block = mock.MagicMock() + block.text = 'partial' + msg = mock.MagicMock() + msg.content = [block] + + async def fake_query(**kwargs): + yield msg + raise RuntimeError('agent crashed') + + with mock.patch.object(agent, 'query', fake_query, create=True): + with terminal.capture(): + opts = mock.MagicMock() + success, log = asyncio.run( + agent.run_agent_collect('prompt', opts)) + + self.assertFalse(success) + self.assertEqual(log, 'partial') + + def test_no_content(self): + """Test messages without content are skipped.""" + msg = mock.MagicMock(spec=[]) + + async def fake_query(**kwargs): + yield msg + + with mock.patch.object(agent, 'query', fake_query, create=True): + with terminal.capture(): + opts = mock.MagicMock() + success, log = asyncio.run( + agent.run_agent_collect('prompt', opts)) + + self.assertTrue(success) + self.assertEqual(log, '') + + class TestSignalFile(unittest.TestCase): """Tests for signal file handling.""" -- 2.43.0