diff --git a/contrib/pg_tde/src/pg_tde_event_capture.c b/contrib/pg_tde/src/pg_tde_event_capture.c index 1f13c677a14..3a5bcbfc9b0 100644 --- a/contrib/pg_tde/src/pg_tde_event_capture.c +++ b/contrib/pg_tde/src/pg_tde_event_capture.c @@ -46,6 +46,15 @@ GetCurrentTdeCreateEvent(void) return &tdeCurrentCreateEvent; } +static bool +shouldEncryptTable(const char *accessMethod) +{ + if (accessMethod) + return strcmp(accessMethod, "tde_heap") == 0; + else + return strcmp(default_table_access_method, "tde_heap") == 0; +} + static void checkPrincipalKeyConfigured(void) { @@ -56,17 +65,8 @@ checkPrincipalKeyConfigured(void) } static void -checkEncryptionClause(const char *accessMethod) +checkEncryptionStatus(void) { - if (accessMethod && strcmp(accessMethod, "tde_heap") == 0) - { - tdeCurrentCreateEvent.encryptMode = true; - } - else if (accessMethod == NULL && strcmp(default_table_access_method, "tde_heap") == 0) - { - tdeCurrentCreateEvent.encryptMode = true; - } - if (tdeCurrentCreateEvent.encryptMode) { checkPrincipalKeyConfigured(); @@ -155,27 +155,30 @@ pg_tde_ddl_command_start_capture(PG_FUNCTION_ARGS) else if (IsA(parsetree, CreateStmt)) { CreateStmt *stmt = (CreateStmt *) parsetree; - const char *accessMethod = stmt->accessMethod; validateCurrentEventTriggerState(true); tdeCurrentCreateEvent.tid = GetCurrentFullTransactionId(); - tdeCurrentCreateEvent.relation = stmt->relation; - checkEncryptionClause(accessMethod); + if (shouldEncryptTable(stmt->accessMethod)) + tdeCurrentCreateEvent.encryptMode = true; + + checkEncryptionStatus(); } else if (IsA(parsetree, CreateTableAsStmt)) { CreateTableAsStmt *stmt = (CreateTableAsStmt *) parsetree; - const char *accessMethod = stmt->into->accessMethod; validateCurrentEventTriggerState(true); tdeCurrentCreateEvent.tid = GetCurrentFullTransactionId(); tdeCurrentCreateEvent.relation = stmt->into->rel; - checkEncryptionClause(accessMethod); + if (shouldEncryptTable(stmt->into->accessMethod)) + tdeCurrentCreateEvent.encryptMode = true; + + checkEncryptionStatus(); } else if (IsA(parsetree, AlterTableStmt)) { @@ -192,13 +195,15 @@ pg_tde_ddl_command_start_capture(PG_FUNCTION_ARGS) if (cmd->subtype == AT_SetAccessMethod) { - const char *accessMethod = cmd->name; - tdeCurrentCreateEvent.relation = stmt->relation; tdeCurrentCreateEvent.baseTableOid = relationId; tdeCurrentCreateEvent.alterAccessMethodMode = true; - checkEncryptionClause(accessMethod); + if (shouldEncryptTable(cmd->name)) + tdeCurrentCreateEvent.encryptMode = true; + + checkEncryptionStatus(); + alterSetAccessMethod = true; } }