@ -22,10 +22,26 @@ import sys
import time
import traceback
from types import TracebackType
from typing import Dict , Iterable , Optional , Set , Tuple , Type , cast
from typing import (
Any ,
Awaitable ,
Callable ,
Dict ,
Generator ,
Iterable ,
List ,
NoReturn ,
Optional ,
Set ,
Tuple ,
Type ,
TypeVar ,
cast ,
)
import yaml
from matrix_common . versionstring import get_distribution_version_string
from typing_extensions import TypedDict
from twisted . internet import defer , reactor as reactor_
@ -36,7 +52,7 @@ from synapse.logging.context import (
make_deferred_yieldable ,
run_in_background ,
)
from synapse . storage . database import DatabasePool , make_conn
from synapse . storage . database import DatabasePool , LoggingTransaction , make_conn
from synapse . storage . databases . main import PushRuleStore
from synapse . storage . databases . main . account_data import AccountDataWorkerStore
from synapse . storage . databases . main . client_ips import ClientIpBackgroundUpdateStore
@ -173,6 +189,8 @@ end_error_exec_info: Optional[
Tuple [ Type [ BaseException ] , BaseException , TracebackType ]
] = None
R = TypeVar ( " R " )
class Store (
ClientIpBackgroundUpdateStore ,
@ -195,17 +213,19 @@ class Store(
PresenceBackgroundUpdateStore ,
GroupServerWorkerStore ,
) :
def execute(self, f: "Callable[..., R]", *args: Any, **kwargs: Any) -> "Awaitable[R]":
    """Run `f` inside a database transaction.

    Args:
        f: the transaction function; its ``__name__`` is used as the
            transaction description for logging/metrics.
        *args: positional arguments forwarded to `f` after the txn.
        **kwargs: keyword arguments forwarded to `f`.

    Returns:
        An awaitable resolving to `f`'s return value.
    """
    return self.db_pool.runInteraction(f.__name__, f, *args, **kwargs)
def execute_sql(self, sql: str, *args: object) -> "Awaitable[List[Tuple]]":
    """Execute `sql` with `args` in a transaction and return all rows."""

    def r(txn: "LoggingTransaction") -> "List[Tuple]":
        txn.execute(sql, args)
        return txn.fetchall()

    return self.db_pool.runInteraction("execute_sql", r)
def insert_many_txn ( self , txn , table , headers , rows ) :
def insert_many_txn (
self , txn : LoggingTransaction , table : str , headers : List [ str ] , rows : List [ Tuple ]
) - > None :
sql = " INSERT INTO %s ( %s ) VALUES ( %s ) " % (
table ,
" , " . join ( k for k in headers ) ,
@ -218,14 +238,15 @@ class Store(
logger . exception ( " Failed to insert: %s " , table )
raise
# Note: the parent method is an `async def`.
def set_room_is_public(self, room_id: str, is_public: bool) -> NoReturn:
    """Refuse to change a room's public visibility during a port.

    The port is only valid against an empty target database, so this
    path should never be reached; raising loudly flags a bad setup.
    """
    raise Exception(
        "Attempt to set room_is_public during port_db: database not empty?"
    )
class MockHomeserver :
def __init__ ( self , config ) :
def __init__ ( self , config : HomeServerConfig ) :
self . clock = Clock ( reactor )
self . config = config
self . hostname = config . server . server_name
@ -233,24 +254,30 @@ class MockHomeserver:
" matrix-synapse "
)
def get_clock(self) -> "Clock":
    """Return the clock instance created for this mock homeserver."""
    return self.clock
def get_reactor(self) -> "ISynapseReactor":
    """Return the global Twisted reactor used by this script."""
    return reactor
def get_instance_name(self) -> str:
    """Return the worker instance name; the port always runs as master."""
    return "master"
class Porter:
    """Drives the port of a SQLite homeserver database to PostgreSQL."""

    def __init__(
        self,
        sqlite_config: Dict[str, Any],
        progress: "Progress",
        batch_size: int,
        hs_config: "HomeServerConfig",
    ):
        # Parsed database config for the source SQLite database.
        self.sqlite_config = sqlite_config
        # Progress reporter (curses or plain terminal).
        self.progress = progress
        # Number of rows copied per transaction.
        self.batch_size = batch_size
        # Full homeserver config, used to build the target store.
        self.hs_config = hs_config
async def setup_table ( self , table ) :
async def setup_table ( self , table : str ) - > Tuple [ str , int , int , int , int ] :
if table in APPEND_ONLY_TABLES :
# It's safe to just carry on inserting.
row = await self . postgres_store . db_pool . simple_select_one (
@ -292,7 +319,7 @@ class Porter:
)
else :
def delete_all ( txn ) :
def delete_all ( txn : LoggingTransaction ) - > None :
txn . execute (
" DELETE FROM port_from_sqlite3 WHERE table_name = %s " , ( table , )
)
@ -317,7 +344,7 @@ class Porter:
async def get_table_constraints ( self ) - > Dict [ str , Set [ str ] ] :
""" Returns a map of tables that have foreign key constraints to tables they depend on. """
def _get_constraints ( txn ) :
def _get_constraints ( txn : LoggingTransaction ) - > Dict [ str , Set [ str ] ] :
# We can pull the information about foreign key constraints out from
# the postgres schema tables.
sql = """
@ -343,8 +370,13 @@ class Porter:
)
async def handle_table (
self , table , postgres_size , table_size , forward_chunk , backward_chunk
) :
self ,
table : str ,
postgres_size : int ,
table_size : int ,
forward_chunk : int ,
backward_chunk : int ,
) - > None :
logger . info (
" Table %s : %i / %i (rows %i - %i ) already ported " ,
table ,
@ -391,7 +423,9 @@ class Porter:
while True :
def r ( txn ) :
def r (
txn : LoggingTransaction ,
) - > Tuple [ Optional [ List [ str ] ] , List [ Tuple ] , List [ Tuple ] ] :
forward_rows = [ ]
backward_rows = [ ]
if do_forward [ 0 ] :
@ -418,6 +452,7 @@ class Porter:
)
if frows or brows :
assert headers is not None
if frows :
forward_chunk = max ( row [ 0 ] for row in frows ) + 1
if brows :
@ -426,7 +461,8 @@ class Porter:
rows = frows + brows
rows = self . _convert_rows ( table , headers , rows )
def insert ( txn ) :
def insert ( txn : LoggingTransaction ) - > None :
assert headers is not None
self . postgres_store . insert_many_txn ( txn , table , headers [ 1 : ] , rows )
self . postgres_store . db_pool . simple_update_one_txn (
@ -448,8 +484,12 @@ class Porter:
return
async def handle_search_table (
self , postgres_size , table_size , forward_chunk , backward_chunk
) :
self ,
postgres_size : int ,
table_size : int ,
forward_chunk : int ,
backward_chunk : int ,
) - > None :
select = (
" SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering "
" FROM event_search as es "
@ -460,7 +500,7 @@ class Porter:
while True :
def r ( txn ) :
def r ( txn : LoggingTransaction ) - > Tuple [ List [ str ] , List [ Tuple ] ] :
txn . execute ( select , ( forward_chunk , self . batch_size ) )
rows = txn . fetchall ( )
headers = [ column [ 0 ] for column in txn . description ]
@ -474,7 +514,7 @@ class Porter:
# We have to treat event_search differently since it has a
# different structure in the two different databases.
def insert ( txn ) :
def insert ( txn : LoggingTransaction ) - > None :
sql = (
" INSERT INTO event_search (event_id, room_id, key, "
" sender, vector, origin_server_ts, stream_ordering) "
@ -528,7 +568,7 @@ class Porter:
self ,
db_config : DatabaseConnectionConfig ,
allow_outdated_version : bool = False ,
) :
) - > Store :
""" Builds and returns a database store using the provided configuration.
Args :
@ -556,7 +596,7 @@ class Porter:
return store
async def run_background_updates_on_postgres ( self ) :
async def run_background_updates_on_postgres ( self ) - > None :
# Manually apply all background updates on the PostgreSQL database.
postgres_ready = (
await self . postgres_store . db_pool . updates . has_completed_background_updates ( )
@ -568,12 +608,12 @@ class Porter:
self . progress . set_state ( " Running background updates on PostgreSQL " )
while not postgres_ready :
await self . postgres_store . db_pool . updates . do_next_background_update ( 100 )
await self . postgres_store . db_pool . updates . do_next_background_update ( True )
postgres_ready = await (
self . postgres_store . db_pool . updates . has_completed_background_updates ( )
)
async def run ( self ) :
async def run ( self ) - > None :
""" Ports the SQLite database to a PostgreSQL database.
When a fatal error is met , its message is assigned to the global " end_error "
@ -609,7 +649,7 @@ class Porter:
self . progress . set_state ( " Creating port tables " )
def create_port_table ( txn ) :
def create_port_table ( txn : LoggingTransaction ) - > None :
txn . execute (
" CREATE TABLE IF NOT EXISTS port_from_sqlite3 ( "
" table_name varchar(100) NOT NULL UNIQUE, "
@ -622,7 +662,7 @@ class Porter:
# We want people to be able to rerun this script from an old port
# so that they can pick up any missing events that were not
# ported across.
def alter_table ( txn ) :
def alter_table ( txn : LoggingTransaction ) - > None :
txn . execute (
" ALTER TABLE IF EXISTS port_from_sqlite3 "
" RENAME rowid TO forward_rowid "
@ -742,7 +782,9 @@ class Porter:
finally :
reactor . stop ( )
def _convert_rows ( self , table , headers , rows ) :
def _convert_rows (
self , table : str , headers : List [ str ] , rows : List [ Tuple ]
) - > List [ Tuple ] :
bool_col_names = BOOLEAN_COLUMNS . get ( table , [ ] )
bool_cols = [ i for i , h in enumerate ( headers ) if h in bool_col_names ]
@ -750,7 +792,7 @@ class Porter:
class BadValueException ( Exception ) :
pass
def conv ( j , col ) :
def conv ( j : int , col : object ) - > object :
if j in bool_cols :
return bool ( col )
if isinstance ( col , bytes ) :
@ -776,7 +818,7 @@ class Porter:
return outrows
async def _setup_sent_transactions ( self ) :
async def _setup_sent_transactions ( self ) - > Tuple [ int , int , int ] :
# Only save things from the last day
yesterday = int ( time . time ( ) * 1000 ) - 86400000
@ -788,10 +830,10 @@ class Porter:
" ) "
)
def r ( txn ) :
def r ( txn : LoggingTransaction ) - > Tuple [ List [ str ] , List [ Tuple ] ] :
txn . execute ( select )
rows = txn . fetchall ( )
headers = [ column [ 0 ] for column in txn . description ]
headers : List [ str ] = [ column [ 0 ] for column in txn . description ]
ts_ind = headers . index ( " ts " )
@ -805,7 +847,7 @@ class Porter:
if inserted_rows :
max_inserted_rowid = max ( r [ 0 ] for r in rows )
def insert ( txn ) :
def insert ( txn : LoggingTransaction ) - > None :
self . postgres_store . insert_many_txn (
txn , " sent_transactions " , headers [ 1 : ] , rows
)
@ -814,7 +856,7 @@ class Porter:
else :
max_inserted_rowid = 0
def get_start_id ( txn ) :
def get_start_id ( txn : LoggingTransaction ) - > int :
txn . execute (
" SELECT rowid FROM sent_transactions WHERE ts >= ? "
" ORDER BY rowid ASC LIMIT 1 " ,
@ -839,12 +881,13 @@ class Porter:
} ,
)
def get_sent_table_size ( txn ) :
def get_sent_table_size ( txn : LoggingTransaction ) - > int :
txn . execute (
" SELECT count(*) FROM sent_transactions " " WHERE ts >= ? " , ( yesterday , )
)
( size , ) = txn . fetchone ( )
return int ( size )
result = txn . fetchone ( )
assert result is not None
return int ( result [ 0 ] )
remaining_count = await self . sqlite_store . execute ( get_sent_table_size )
@ -852,25 +895,35 @@ class Porter:
return next_chunk , inserted_rows , total_count
async def _get_remaining_count_to_port(
    self, table: str, forward_chunk: int, backward_chunk: int
) -> int:
    """Count the rows of `table` still to be ported.

    Rows on both sides of the chunk window remain: those at or after the
    forward pointer, plus those at or before the backward pointer.
    """
    frows = cast(
        List[Tuple[int]],
        await self.sqlite_store.execute_sql(
            "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), forward_chunk
        ),
    )

    brows = cast(
        List[Tuple[int]],
        await self.sqlite_store.execute_sql(
            "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), backward_chunk
        ),
    )

    return frows[0][0] + brows[0][0]
async def _get_already_ported_count(self, table: str) -> int:
    """Return how many rows of `table` already exist in PostgreSQL."""
    rows = await self.postgres_store.execute_sql(
        "SELECT count(*) FROM %s" % (table,)
    )
    return rows[0][0]
async def _get_total_count_to_port ( self , table , forward_chunk , backward_chunk ) :
async def _get_total_count_to_port (
self , table : str , forward_chunk : int , backward_chunk : int
) - > Tuple [ int , int ] :
remaining , done = await make_deferred_yieldable (
defer . gatherResults (
[
@ -891,14 +944,17 @@ class Porter:
return done , remaining + done
async def _setup_state_group_id_seq ( self ) - > None :
curr_id = await self . sqlite_store . db_pool . simple_select_one_onecol (
curr_id : Optional [
int
] = await self . sqlite_store . db_pool . simple_select_one_onecol (
table = " state_groups " , keyvalues = { } , retcol = " MAX(id) " , allow_none = True
)
if not curr_id :
return
def r ( txn ) :
def r ( txn : LoggingTransaction ) - > None :
assert curr_id is not None
next_id = curr_id + 1
txn . execute ( " ALTER SEQUENCE state_group_id_seq RESTART WITH %s " , ( next_id , ) )
@ -909,7 +965,7 @@ class Porter:
" setup_user_id_seq " , find_max_generated_user_id_localpart
)
def r ( txn ) :
def r ( txn : LoggingTransaction ) - > None :
next_id = curr_id + 1
txn . execute ( " ALTER SEQUENCE user_id_seq RESTART WITH %s " , ( next_id , ) )
@ -931,7 +987,7 @@ class Porter:
allow_none = True ,
)
def _setup_events_stream_seqs_set_pos ( txn ) :
def _setup_events_stream_seqs_set_pos ( txn : LoggingTransaction ) - > None :
if curr_forward_id :
txn . execute (
" ALTER SEQUENCE events_stream_seq RESTART WITH %s " ,
@ -955,17 +1011,20 @@ class Porter:
""" Set a sequence to the correct value. """
current_stream_ids = [ ]
for stream_id_table in stream_id_tables :
max_stream_id = await self . sqlite_store . db_pool . simple_select_one_onecol (
table = stream_id_table ,
keyvalues = { } ,
retcol = " COALESCE(MAX(stream_id), 1) " ,
allow_none = True ,
max_stream_id = cast (
int ,
await self . sqlite_store . db_pool . simple_select_one_onecol (
table = stream_id_table ,
keyvalues = { } ,
retcol = " COALESCE(MAX(stream_id), 1) " ,
allow_none = True ,
) ,
)
current_stream_ids . append ( max_stream_id )
next_id = max ( current_stream_ids ) + 1
def r ( txn ) :
def r ( txn : LoggingTransaction ) - > None :
sql = " ALTER SEQUENCE %s RESTART WITH " % ( sequence_name , )
txn . execute ( sql + " %s " , ( next_id , ) )
@ -974,14 +1033,18 @@ class Porter:
)
async def _setup_auth_chain_sequence ( self ) - > None :
curr_chain_id = await self . sqlite_store . db_pool . simple_select_one_onecol (
curr_chain_id : Optional [
int
] = await self . sqlite_store . db_pool . simple_select_one_onecol (
table = " event_auth_chains " ,
keyvalues = { } ,
retcol = " MAX(chain_id) " ,
allow_none = True ,
)
def r ( txn ) :
def r ( txn : LoggingTransaction ) - > None :
# Presumably there is at least one row in event_auth_chains.
assert curr_chain_id is not None
txn . execute (
" ALTER SEQUENCE event_auth_chain_id RESTART WITH %s " ,
( curr_chain_id + 1 , ) ,
@ -999,15 +1062,22 @@ class Porter:
##############################################
class Progress ( object ) :
class TableProgress(TypedDict):
    """Progress bookkeeping for one table being ported."""

    start: int  # rows already done when this run began
    num_done: int  # rows ported so far (includes `start`)
    total: int  # total rows to port for this table
    perc: int  # integer percentage complete
class Progress:
    """Used to report progress of the port"""

    def __init__(self) -> None:
        # Per-table progress, keyed by table name.
        self.tables: Dict[str, "TableProgress"] = {}
        # Port start time, in whole epoch seconds.
        self.start_time = int(time.time())
def add_table ( self , table , cur , size ) :
def add_table ( self , table : str , cur : int , size : int ) - > None :
self . tables [ table ] = {
" start " : cur ,
" num_done " : cur ,
@ -1015,19 +1085,22 @@ class Progress(object):
" perc " : int ( cur * 100 / size ) ,
}
def update(self, table: str, num_done: int) -> None:
    """Record that `num_done` rows of `table` have now been ported."""
    data = self.tables[table]
    data["num_done"] = num_done
    # Keep the integer completion percentage in sync.
    data["perc"] = int(num_done * 100 / data["total"])
def done(self) -> None:
    """Hook called when the port completes; no-op for the base reporter."""
    pass
def set_state(self, state: str) -> None:
    """Hook called when the port enters a new phase; no-op here."""
    pass
class CursesProgress ( Progress ) :
""" Reports progress to a curses window """
def __init__ ( self , stdscr ) :
def __init__ ( self , stdscr : " curses.window " ) :
self . stdscr = stdscr
curses . use_default_colors ( )
@ -1045,7 +1118,7 @@ class CursesProgress(Progress):
super ( CursesProgress , self ) . __init__ ( )
def update ( self , table , num_done ) :
def update ( self , table : str , num_done : int ) - > None :
super ( CursesProgress , self ) . update ( table , num_done )
self . total_processed = 0
@ -1056,7 +1129,7 @@ class CursesProgress(Progress):
self . render ( )
def render ( self , force = False ) :
def render ( self , force : bool = False ) - > None :
now = time . time ( )
if not force and now - self . last_update < 0.2 :
@ -1128,12 +1201,12 @@ class CursesProgress(Progress):
self . stdscr . refresh ( )
self . last_update = time . time ( )
def done(self) -> None:
    """Mark the port finished, force a final render, and wait for a key."""
    self.finished = True
    self.render(True)
    # Block until the user acknowledges, so the final screen stays visible.
    self.stdscr.getch()
def set_state(self, state: str) -> None:
    """Show the current phase name, bolded, on a cleared curses screen."""
    self.stdscr.clear()
    self.stdscr.addstr(0, 0, state + "...", curses.A_BOLD)
    self.stdscr.refresh()
@ -1142,7 +1215,7 @@ class CursesProgress(Progress):
class TerminalProgress ( Progress ) :
""" Just prints progress to the terminal """
def update ( self , table , num_done ) :
def update ( self , table : str , num_done : int ) - > None :
super ( TerminalProgress , self ) . update ( table , num_done )
data = self . tables [ table ]
@ -1151,7 +1224,7 @@ class TerminalProgress(Progress):
" %s : %d %% ( %d / %d ) " % ( table , data [ " perc " ] , data [ " num_done " ] , data [ " total " ] )
)
def set_state(self, state: str) -> None:
    """Print the current phase of the port to the terminal."""
    print(state + "...")
@ -1159,7 +1232,7 @@ class TerminalProgress(Progress):
##############################################
def main ( ) :
def main ( ) - > None :
parser = argparse . ArgumentParser (
description = " A script to port an existing synapse SQLite database to "
" a new PostgreSQL database. "
@ -1225,7 +1298,7 @@ def main():
config = HomeServerConfig ( )
config . parse_config_dict ( hs_config , " " , " " )
def start ( stdscr = None ) :
def start ( stdscr : Optional [ " curses.window " ] = None ) - > None :
progress : Progress
if stdscr :
progress = CursesProgress ( stdscr )
@ -1240,7 +1313,7 @@ def main():
)
@defer . inlineCallbacks
def run ( ) :
def run ( ) - > Generator [ " defer.Deferred[Any] " , Any , None ] :
with LoggingContext ( " synapse_port_db_run " ) :
yield defer . ensureDeferred ( porter . run ( ) )