Skip to content

Commit

Permalink
Merge pull request #85 from passcod/feat/colsubset
Browse files Browse the repository at this point in the history
Support publication column lists
  • Loading branch information
imor authored Jan 15, 2025
2 parents 1afbd50 + 16b07f9 commit e5c99a8
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 9 deletions.
54 changes: 47 additions & 7 deletions pg_replicate/src/clients/postgres.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::collections::HashMap;
use std::{collections::HashMap, fmt::Write};

use pg_escape::{quote_identifier, quote_literal};
use postgres_replication::LogicalReplicationStream;
Expand Down Expand Up @@ -125,24 +125,59 @@ impl ReplicationClient {
/// Starts a `COPY ... TO STDOUT` (text format) for `table_name` and returns the
/// resulting copy-out stream.
///
/// Only the columns named in `column_schemas` are copied, in slice order, so the
/// rows on the wire line up with the schema the caller will decode them with
/// (for example when a publication restricts the table to a column subset).
///
/// If `column_schemas` is empty, no explicit column list is emitted and the
/// plain `COPY <table> TO STDOUT` form is used instead: `COPY t () ...` is not
/// valid SQL, and panicking inside a `Result`-returning API is worse than
/// letting the server decide.
///
/// # Errors
///
/// Propagates any error returned by `copy_out_simple` via `?`.
pub async fn get_table_copy_stream(
    &self,
    table_name: &TableName,
    column_schemas: &[ColumnSchema],
) -> Result<CopyOutStream, ReplicationClientError> {
    // Build a comma-separated list of quoted column identifiers.
    let mut column_list = String::new();
    for (idx, col) in column_schemas.iter().enumerate() {
        if idx > 0 {
            column_list.push_str(", ");
        }
        // Writing into a `String` cannot fail, so the result is safe to ignore.
        let _ = write!(column_list, "{}", quote_identifier(&col.name));
    }

    let copy_query = if column_list.is_empty() {
        // No known columns: fall back to copying the table without a column list.
        format!(
            r#"COPY {} TO STDOUT WITH (FORMAT text);"#,
            table_name.as_quoted_identifier(),
        )
    } else {
        format!(
            r#"COPY {} ({column_list}) TO STDOUT WITH (FORMAT text);"#,
            table_name.as_quoted_identifier(),
        )
    };

    let stream = self.postgres_client.copy_out_simple(&copy_query).await?;

    Ok(stream)
}

/// Returns a vector of columns of a table
/// Returns a vector of columns of a table, optionally filtered by a publication's column list
pub async fn get_column_schemas(
&self,
table_id: TableId,
publication: Option<&str>,
) -> Result<Vec<ColumnSchema>, ReplicationClientError> {
let (pub_cte, pub_pred) = if let Some(publication) = publication {
(
format!(
"with pub_attrs as (
select unnest(r.prattrs)
from pg_publication_rel r
left join pg_publication p on r.prpubid = p.oid
where p.pubname = {publication}
and r.prrelid = {table_id}
)",
publication = quote_literal(publication),
),
"and (
case (select count(*) from pub_attrs)
when 0 then true
else (a.attnum in (select * from pub_attrs))
end
)",
)
} else {
("".into(), "")
};

let column_info_query = format!(
"select a.attname,
"{pub_cte}
select a.attname,
a.atttypid,
a.atttypmod,
a.attnotnull,
Expand All @@ -156,6 +191,7 @@ impl ReplicationClient {
and not a.attisdropped
and a.attgenerated = ''
and a.attrelid = {table_id}
{pub_pred}
order by a.attnum
",
);
Expand Down Expand Up @@ -234,11 +270,14 @@ impl ReplicationClient {
pub async fn get_table_schemas(
&self,
table_names: &[TableName],
publication: Option<&str>,
) -> Result<HashMap<TableId, TableSchema>, ReplicationClientError> {
let mut table_schemas = HashMap::new();

for table_name in table_names {
let table_schema = self.get_table_schema(table_name.clone()).await?;
let table_schema = self
.get_table_schema(table_name.clone(), publication)
.await?;
if !table_schema.has_primary_keys() {
warn!(
"table {} with id {} will not be copied because it has no primary key",
Expand All @@ -255,12 +294,13 @@ impl ReplicationClient {
async fn get_table_schema(
&self,
table_name: TableName,
publication: Option<&str>,
) -> Result<TableSchema, ReplicationClientError> {
let table_id = self
.get_table_id(&table_name)
.await?
.ok_or(ReplicationClientError::MissingTable(table_name.clone()))?;
let column_schemas = self.get_column_schemas(table_id).await?;
let column_schemas = self.get_column_schemas(table_id, publication).await?;
Ok(TableSchema {
table_name,
table_id,
Expand Down
6 changes: 4 additions & 2 deletions pg_replicate/src/pipeline/sources/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,9 @@ impl PostgresSource {
}
let (table_names, publication) =
Self::get_table_names_and_publication(&replication_client, table_names_from).await?;
let table_schemas = replication_client.get_table_schemas(&table_names).await?;
let table_schemas = replication_client
.get_table_schemas(&table_names, publication.as_deref())
.await?;
Ok(PostgresSource {
replication_client,
table_schemas,
Expand Down Expand Up @@ -125,7 +127,7 @@ impl Source for PostgresSource {

let stream = self
.replication_client
.get_table_copy_stream(table_name)
.get_table_copy_stream(table_name, column_schemas)
.await
.map_err(PostgresSourceError::ReplicationClient)?;

Expand Down

0 comments on commit e5c99a8

Please sign in to comment.