diff --git a/src/database/archive.py b/src/database/archive.py index 2d9cd3d..7690f31 100644 --- a/src/database/archive.py +++ b/src/database/archive.py @@ -7,12 +7,12 @@ class Archive(DatabaseConnection): super().__init__(db_path, table_name) self.uuid = archive_json['id'] - self.repo_id = repo.repo_id + self.repo_id = repo.primary_key self.name = archive_json['name'] self.start = datetime.fromisoformat(archive_json['start']) self.end = datetime.fromisoformat(archive_json['end']) - self.archive_id = self._insert() + self.insert() def _create_table(self): create_statement = f"create table if not exists {self._sql_table}(" \ @@ -29,13 +29,11 @@ class Archive(DatabaseConnection): result = self.sql_execute_one(f"SELECT archive_id FROM {self._sql_table}" f" WHERE uuid=?;", (self.uuid,)) if result is None: - return None + return False, None else: - return result[0] + return True, result[0] def _insert(self) -> int: - if self._exists(): - raise Exception("archive with same uuid already exists") with self.sql_lock: cursor = self.sql_cursor statement = f"INSERT INTO {self._sql_table}"\ diff --git a/src/database/databaseconnection.py b/src/database/databaseconnection.py index c3cf21d..5b95707 100644 --- a/src/database/databaseconnection.py +++ b/src/database/databaseconnection.py @@ -16,6 +16,8 @@ class DatabaseConnection(ABC): self._create_table() self.sql_commit() + self.primary_key = None + @property def sql_lock(self): return self.__sql_lock @@ -66,6 +68,28 @@ class DatabaseConnection(ABC): def sql_commit(self): self.__sql_database.commit() + def insert(self): + if self.exists(): + raise Exception("Record exists") + elif self.primary_key is not None: + raise Exception("Primary key already set") + else: + self.primary_key = self._insert() + + @abstractmethod + def _insert(self) -> int: + raise NotImplementedError + + def exists(self) -> bool: + exists, primary_key = self._exists() + if exists: + self.primary_key = primary_key + return exists + + @abstractmethod + def _exists(self) -> (bool, list): + raise NotImplementedError + @abstractmethod def _create_table(self): raise NotImplementedError diff --git a/src/database/repo.py b/src/database/repo.py index a43452e..848399e 100644 --- a/src/database/repo.py +++ b/src/database/repo.py @@ -6,17 +6,14 @@ class Repo(DatabaseConnection): def __init__(self, db_path, repo_json: dict, table_name: str = 'repo'): super(Repo, self).__init__(db_path, table_name) - self.repo_id = None self.uuid = repo_json['id'] self.location = repo_json['location'] self.last_modified = datetime.fromisoformat(repo_json['last_modified']) - repo_id = self._exists() - if repo_id is None: - self.repo_id = self._insert() - else: - self.repo_id = repo_id + if self.exists(): self._update() + else: + self.insert() def _insert(self) -> int: with self.sql_lock: @@ -31,16 +28,16 @@ class Repo(DatabaseConnection): def _update(self): self.sql_execute(f"UPDATE {self._sql_table} SET location = ?, last_modified = ? WHERE repo_id = ?;", - (self.location, self.last_modified, self.repo_id)) + (self.location, self.last_modified, self.primary_key)) self.sql_commit() def _exists(self): result = self.sql_execute_one(f"SELECT repo_id FROM {self._sql_table}" f" WHERE uuid=?;", (self.uuid,)) if result is None: - return None + return False, None else: - return result[0] + return True, result[0] def _create_table(self): create_statement = f"create table if not exists {self._sql_table}(" \ diff --git a/src/database/stats.py b/src/database/stats.py index 2eb8532..a2f228d 100644 --- a/src/database/stats.py +++ b/src/database/stats.py @@ -5,15 +5,14 @@ class Stats(DatabaseConnection): def __init__(self, db_path, repo: Repo, archive: Archive, stats_json: dict, table_name: str = "stats"): super().__init__(db_path, table_name) - self.stat_id = None - self.repo_id = repo.repo_id - self.archive_id = archive.archive_id + self.repo_id = repo.primary_key + self.archive_id = archive.primary_key self.file_count = stats_json['nfiles'] self.original_size = stats_json['original_size'] self.compressed_size = stats_json['compressed_size'] self.deduplicated_size = stats_json['deduplicated_size'] - self.stat_id = self._insert() + self.insert() def _create_table(self): create_statement = f"create table if not exists {self._sql_table}(" \ @@ -29,7 +28,7 @@ class Stats(DatabaseConnection): self.sql_execute(create_statement) def _exists(self): - return False + return False, None def _insert(self) -> int: with self.sql_lock: