From: Simon Glass <simon.glass@canonical.com>

Add an sqlite3 database module to track the state of cherry-picking
commits between branches.

The database uses .pickman.db and includes:
- source table: tracks source branches and their last cherry-picked
  commit into master
- Schema versioning for future migrations

The database code is mostly lifted from patman.

Co-developed-by: Claude Opus 4.5 <noreply@anthropic.com>
Signed-off-by: Simon Glass <simon.glass@canonical.com>
---
 tools/pickman/database.py | 209 ++++++++++++++++++++++++++++++++++++++
 tools/pickman/ftest.py    |  93 +++++++++++++++++
 2 files changed, 302 insertions(+)
 create mode 100644 tools/pickman/database.py

diff --git a/tools/pickman/database.py b/tools/pickman/database.py
new file mode 100644
index 00000000000..436734fe1f7
--- /dev/null
+++ b/tools/pickman/database.py
@@ -0,0 +1,209 @@
+# SPDX-License-Identifier: GPL-2.0+
+#
+# Copyright 2025 Canonical Ltd.
+# Written by Simon Glass <simon.glass@canonical.com>
+#
+"""Database for pickman - tracks cherry-pick state.
+
+This uses sqlite3 with a local file (.pickman.db).
+
+To adjust the schema, increment LATEST, create a _migrate_to_v<x>() function
+and add code in migrate_to() to call it.
+"""
+
+import os
+import sqlite3
+
+from u_boot_pylib import tools
+from u_boot_pylib import tout
+
+# Schema version (version 0 means there is no database yet)
+LATEST = 1
+
+# Default database filename
+DB_FNAME = '.pickman.db'
+
+
+class Database:
+    """Database of cherry-pick state used by pickman"""
+
+    # dict of databases:
+    #   key: filename
+    #   value: Database object
+    instances = {}
+
+    def __init__(self, db_path):
+        """Set up a new database object
+
+        Args:
+            db_path (str): Path to the database
+        """
+        if db_path in Database.instances:
+            raise ValueError(f"There is already a database for '{db_path}'")
+        self.con = None
+        self.cur = None
+        self.db_path = db_path
+        self.is_open = False
+        Database.instances[db_path] = self
+
+    @staticmethod
+    def get_instance(db_path):
+        """Get the database instance for a path
+
+        Args:
+            db_path (str): Path to the database
+
+        Return:
+            tuple:
+                Database: Database instance, created if necessary
+                bool: True if newly created
+        """
+        dbs = Database.instances.get(db_path)
+        if dbs:
+            return dbs, False
+        return Database(db_path), True
+
+    def start(self):
+        """Open the database ready for use, migrate to latest schema"""
+        self.open_it()
+        self.migrate_to(LATEST)
+
+    def open_it(self):
+        """Open the database, creating it if necessary"""
+        if self.is_open:
+            raise ValueError('Already open')
+        if not os.path.exists(self.db_path):
+            tout.warning(f'Creating new database {self.db_path}')
+        self.con = sqlite3.connect(self.db_path)
+        self.cur = self.con.cursor()
+        self.is_open = True
+        Database.instances[self.db_path] = self
+
+    def close(self):
+        """Close the database"""
+        if not self.is_open:
+            raise ValueError('Already closed')
+        self.con.close()
+        self.cur = None
+        self.con = None
+        self.is_open = False
+        Database.instances.pop(self.db_path, None)
+
+    def _create_v1(self):
+        """Create a database with the v1 schema
+
+        The tables are created with IF NOT EXISTS so that this can be run
+        again on a database where an earlier attempt was interrupted before
+        the schema version was recorded
+        """
+        # Table for tracking source branches and their last cherry-picked commit
+        self.cur.execute(
+            'CREATE TABLE IF NOT EXISTS source ('
+            'id INTEGER PRIMARY KEY AUTOINCREMENT, '
+            'name TEXT UNIQUE, '
+            'last_commit TEXT)')
+
+        # Schema version table
+        self.cur.execute(
+            'CREATE TABLE IF NOT EXISTS schema_version (version INTEGER)')
+
+    def migrate_to(self, dest_version):
+        """Migrate the database to the selected version
+
+        Before each step the database is closed and its file is copied to
+        '<db_path>old.v<version>' as a backup
+
+        Args:
+            dest_version (int): Version to migrate to
+        """
+        while True:
+            version = self.get_schema_version()
+            if version >= dest_version:
+                break
+
+            self.close()
+            tools.write_file(f'{self.db_path}old.v{version}',
+                             tools.read_file(self.db_path))
+
+            version += 1
+            tout.info(f'Update database to v{version}')
+            self.open_it()
+            if version == 1:
+                self._create_v1()
+
+            self.cur.execute('DELETE FROM schema_version')
+            self.cur.execute(
+                'INSERT INTO schema_version (version) VALUES (?)',
+                (version,))
+            self.commit()
+
+    def get_schema_version(self):
+        """Get the version of the database's schema
+
+        Return:
+            int: Database version, 0 means there is no data
+        """
+        try:
+            self.cur.execute('SELECT version FROM schema_version')
+            rec = self.cur.fetchone()
+        except sqlite3.OperationalError:
+            # Assume the table is missing, i.e. there is no schema yet
+            return 0
+
+        # An empty table means that creation was interrupted before the
+        # version was recorded, so there is no usable schema
+        return rec[0] if rec else 0
+
+    def execute(self, query, parameters=()):
+        """Execute a database query
+
+        Args:
+            query (str): Query string
+            parameters (tuple): Parameters to pass
+
+        Return:
+            Cursor result
+        """
+        return self.cur.execute(query, parameters)
+
+    def commit(self):
+        """Commit changes to the database"""
+        self.con.commit()
+
+    def rollback(self):
+        """Roll back changes to the database"""
+        self.con.rollback()
+
+    # source functions
+
+    def source_get(self, name):
+        """Get the last cherry-picked commit for a source branch
+
+        Args:
+            name (str): Source branch name
+
+        Return:
+            str: Commit hash, or None if not found
+        """
+        res = self.execute(
+            'SELECT last_commit FROM source WHERE name = ?', (name,))
+        rec = res.fetchone()
+        if rec:
+            return rec[0]
+        return None
+
+    def source_set(self, name, commit):
+        """Set the last cherry-picked commit for a source branch
+
+        The change is not committed here; call commit() to save it
+
+        Args:
+            name (str): Source branch name
+            commit (str): Commit hash
+        """
+        self.execute(
+            'UPDATE source SET last_commit = ? WHERE name = ?', (commit, name))
+        if self.cur.rowcount == 0:
+            self.execute(
+                'INSERT INTO source (name, last_commit) VALUES (?, ?)',
+                (name, commit))
diff --git a/tools/pickman/ftest.py b/tools/pickman/ftest.py
index eeb19926f76..b975b9c6a2b 100644
--- a/tools/pickman/ftest.py
+++ b/tools/pickman/ftest.py
@@ -7,6 +7,7 @@
 
 import os
 import sys
+import tempfile
 import unittest
 
 # Allow 'from pickman import xxx' to work via symlink
@@ -19,6 +20,7 @@ from u_boot_pylib import terminal
 
 from pickman import __main__ as pickman
 from pickman import control
+from pickman import database
 
 
 class TestCommit(unittest.TestCase):
@@ -152,5 +154,96 @@ class TestMain(unittest.TestCase):
         command.TEST_RESULT = None
 
 
+class TestDatabase(unittest.TestCase):
+    """Tests for Database class."""
+
+    def setUp(self):
+        """Set up test fixtures."""
+        fd, self.db_path = tempfile.mkstemp(suffix='.db')
+        os.close(fd)
+        os.unlink(self.db_path)  # Remove so database creates it fresh
+        database.Database.instances.clear()
+
+    def tearDown(self):
+        """Clean up test fixtures."""
+        # Each migration step leaves a backup file next to the database
+        fnames = [self.db_path]
+        fnames += [f'{self.db_path}old.v{ver}'
+                   for ver in range(database.LATEST)]
+        for fname in fnames:
+            if os.path.exists(fname):
+                os.unlink(fname)
+        database.Database.instances.clear()
+
+    def test_create_database(self):
+        """Test creating a new database."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            self.assertTrue(dbs.is_open)
+            self.assertEqual(dbs.get_schema_version(), database.LATEST)
+            dbs.close()
+
+    def test_source_get_empty(self):
+        """Test getting source from empty database."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            result = dbs.source_get('us/next')
+            self.assertIsNone(result)
+            dbs.close()
+
+    def test_source_set_and_get(self):
+        """Test setting and getting source commit."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            dbs.source_set('us/next', 'abc123def456')
+            dbs.commit()
+            result = dbs.source_get('us/next')
+            self.assertEqual(result, 'abc123def456')
+            dbs.close()
+
+    def test_source_update(self):
+        """Test updating source commit."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            dbs.source_set('us/next', 'abc123')
+            dbs.commit()
+            dbs.source_set('us/next', 'def456')
+            dbs.commit()
+            result = dbs.source_get('us/next')
+            self.assertEqual(result, 'def456')
+            dbs.close()
+
+    def test_get_instance(self):
+        """Test get_instance returns same database."""
+        with terminal.capture():
+            dbs1, created1 = database.Database.get_instance(self.db_path)
+            dbs1.start()
+            dbs2, created2 = database.Database.get_instance(self.db_path)
+            self.assertTrue(created1)
+            self.assertFalse(created2)
+            self.assertIs(dbs1, dbs2)
+            dbs1.close()
+
+    def test_recover_interrupted_create(self):
+        """Test recovery when creation stopped before the version was set."""
+        with terminal.capture():
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            dbs.source_set('us/next', 'abc123')
+            dbs.execute('DELETE FROM schema_version')
+            dbs.commit()
+            dbs.close()
+
+            dbs = database.Database(self.db_path)
+            dbs.start()
+            self.assertEqual(dbs.get_schema_version(), database.LATEST)
+            self.assertEqual(dbs.source_get('us/next'), 'abc123')
+            dbs.close()
+
+
 if __name__ == '__main__':
     unittest.main()
-- 
2.43.0