Skip to content

Commit

Permalink
Propagate SECURITY LABEL ON COLUMN to all workers if the table is
Browse files Browse the repository at this point in the history
distributed
  • Loading branch information
colm-mchugh committed Jan 13, 2025
1 parent f7bead2 commit e478669
Show file tree
Hide file tree
Showing 7 changed files with 470 additions and 52 deletions.
56 changes: 45 additions & 11 deletions src/backend/distributed/commands/seclabel.c
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,17 @@
#include "distributed/commands/utility_hook.h"
#include "distributed/coordinator_protocol.h"
#include "distributed/deparser.h"
#include "distributed/listutils.h"
#include "distributed/log_utils.h"
#include "distributed/metadata/distobject.h"
#include "distributed/metadata_sync.h"


/*
* PostprocessSecLabelStmt prepares the commands that need to be run on all workers to assign
* security labels on distributed objects, currently supporting just Role objects.
* It also ensures that all object dependencies exist on all
* nodes for the object in the SecLabelStmt.
* security labels on distributed objects, currently supporting just Role, Table and Column
* objects. It also ensures that all object dependencies exist on all nodes for the object
* in the SecLabelStmt.
*/
List *
PostprocessSecLabelStmt(Node *node, const char *queryString)
Expand All @@ -37,12 +38,14 @@ PostprocessSecLabelStmt(Node *node, const char *queryString)
SecLabelStmt *secLabelStmt = castNode(SecLabelStmt, node);

List *objectAddresses = GetObjectAddressListFromParseTree(node, false, true);
if (!IsAnyObjectDistributed(objectAddresses))
if (!IsAnyObjectDistributedIgnoreObjectSubId(objectAddresses))
{
return NIL;
}

if (secLabelStmt->objtype != OBJECT_ROLE)
if (secLabelStmt->objtype != OBJECT_ROLE &&
secLabelStmt->objtype != OBJECT_TABLE &&
secLabelStmt->objtype != OBJECT_COLUMN)
{
/*
* If we are not in the coordinator, we don't want to interrupt the security
Expand All @@ -52,7 +55,7 @@ PostprocessSecLabelStmt(Node *node, const char *queryString)
if (EnableUnsupportedFeatureMessages && IsCoordinator())
{
ereport(NOTICE, (errmsg("not propagating SECURITY LABEL commands whose "
"object type is not role"),
"object type is not role or table or table column"),
errhint("Connect to worker nodes directly to manually "
"run the same SECURITY LABEL command.")));
}
Expand All @@ -63,13 +66,44 @@ PostprocessSecLabelStmt(Node *node, const char *queryString)
EnsurePropagationToCoordinator();
EnsureAllObjectDependenciesExistOnAllNodes(objectAddresses);

const char *secLabelCommands = DeparseTreeNode((Node *) secLabelStmt);
List *commandList = NULL;

List *commandList = list_make3(DISABLE_DDL_PROPAGATION,
(void *) secLabelCommands,
ENABLE_DDL_PROPAGATION);
if (secLabelStmt->objtype == OBJECT_ROLE ||
secLabelStmt->objtype == OBJECT_TABLE ||
secLabelStmt->objtype == OBJECT_COLUMN)
{
const char *secLabelCommands = DeparseTreeNode((Node *) secLabelStmt);
commandList = list_make3(DISABLE_DDL_PROPAGATION,
(void *) secLabelCommands,
ENABLE_DDL_PROPAGATION);
}

List *DDLJobs = NodeDDLTaskList(REMOTE_NODES, commandList);

/*
* If the label is for a table or a column, we need to set the targetObjectAddress
* of the DDLJob to the relationId of the table. This is needed to ensure that
* the search path is correctly set for the remote security label command; it
* needs to be able to resolve the table that the label is being defined on.
*/
if (secLabelStmt->objtype == OBJECT_TABLE ||
secLabelStmt->objtype == OBJECT_COLUMN)
{
ObjectAddress *target = NULL;
Oid relationId = InvalidOid;
foreach_ptr(target, objectAddresses)
{
relationId = target->objectId;
}
Assert(relationId != InvalidOid);
DDLJob *ddlJob = NULL;
foreach_ptr(ddlJob, DDLJobs)
{
ObjectAddressSet(ddlJob->targetObjectAddress, RelationRelationId, relationId);
}
}

return NodeDDLTaskList(REMOTE_NODES, commandList);
return DDLJobs;
}


Expand Down
31 changes: 30 additions & 1 deletion src/backend/distributed/deparser/deparse_seclabel_stmts.c
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

#include "postgres.h"

#include "catalog/namespace.h"
#include "nodes/parsenodes.h"
#include "utils/builtins.h"

Expand Down Expand Up @@ -54,7 +55,35 @@ AppendSecLabelStmt(StringInfo buf, SecLabelStmt *stmt)
{
case OBJECT_ROLE:
{
appendStringInfo(buf, "ROLE %s ", quote_identifier(strVal(stmt->object)));
char *role_name = strVal(stmt->object);
appendStringInfo(buf, "ROLE %s ", quote_identifier(role_name));
break;
}

case OBJECT_TABLE:
{
List *names = (List *) stmt->object;
appendStringInfo(buf, "TABLE %s", quote_identifier(strVal(linitial(names))));
if (list_length(names) > 1)
{
appendStringInfo(buf, ".%s", quote_identifier(strVal(lsecond(names))));
}
appendStringInfoString(buf, " ");
break;
}

case OBJECT_COLUMN:
{
List *names = (List *) stmt->object;
Assert(list_length(names) >= 2);
appendStringInfo(buf, "COLUMN %s.%s",
quote_identifier(strVal(linitial(names))),
quote_identifier(strVal(lsecond(names))));
if (list_length(names) > 2)
{
appendStringInfo(buf, ".%s", quote_identifier(strVal(lthird(names))));
}
appendStringInfoString(buf, " ");
break;
}

Expand Down
31 changes: 31 additions & 0 deletions src/backend/distributed/metadata/distobject.c
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,37 @@ IsAnyObjectDistributed(const List *addresses)
}


/*
* IsAnyObjectDistributedIgnoreObjectSubId determines if any of the given
* addresses are distributed, using IsObjectDistributed(). It disregards
* the object sub-id field of an address, so this is saved and restored
* before and after each call to IsObjectDistributed(). It is used in
* situations where an address by sub object id is not distributed, but
* the same address by object id is distributed; for example, the address
* of a column of a distributed table.
*/
bool
IsAnyObjectDistributedIgnoreObjectSubId(const List *addresses)
{
ObjectAddress *address = NULL;
bool isDistributed = false;
foreach_ptr(address, addresses)
{
int32 savedObjectSubId = address->objectSubId;
address->objectSubId = 0;
isDistributed = IsObjectDistributed(address);
address->objectSubId = savedObjectSubId;

if (isDistributed)
{
break;
}
}

return isDistributed;
}


/*
* GetDistributedObjectAddressList returns a list of ObjectAddresses that contains all
* distributed objects as marked in pg_dist_object
Expand Down
109 changes: 109 additions & 0 deletions src/backend/distributed/operations/node_protocol.c
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "catalog/pg_constraint.h"
#include "catalog/pg_index.h"
#include "catalog/pg_namespace.h"
#include "catalog/pg_seclabel.h"
#include "catalog/pg_type.h"
#include "commands/sequence.h"
#include "foreign/foreign.h"
Expand All @@ -57,6 +58,7 @@
#include "distributed/citus_ruleutils.h"
#include "distributed/commands.h"
#include "distributed/coordinator_protocol.h"
#include "distributed/deparser.h"
#include "distributed/listutils.h"
#include "distributed/metadata_cache.h"
#include "distributed/metadata_sync.h"
Expand All @@ -83,6 +85,7 @@ static char * CitusCreateAlterColumnarTableSet(char *qualifiedRelationName,
const ColumnarOptions *options);
static char * GetTableDDLCommandColumnar(void *context);
static TableDDLCommand * ColumnarGetTableOptionsDDL(Oid relationId);
static List * CreateSecurityLabelCommands(Oid relationId);

/* exports for SQL callable functions */
PG_FUNCTION_INFO_V1(master_get_table_metadata);
Expand Down Expand Up @@ -665,6 +668,9 @@ GetPreLoadTableCreationCommands(Oid relationId,
List *policyCommands = CreatePolicyCommands(relationId);
tableDDLEventList = list_concat(tableDDLEventList, policyCommands);

List *securityLabelCommands = CreateSecurityLabelCommands(relationId);
tableDDLEventList = list_concat(tableDDLEventList, securityLabelCommands);

/* revert back to original search_path */
PopEmptySearchPath(saveNestLevel);

Expand Down Expand Up @@ -833,6 +839,109 @@ GetTableRowLevelSecurityCommands(Oid relationId)
}


/*
* CreateSecurityLabelCommands takes in a relationId, and returns the
* list of SECURITY LABEL commands on the relation and on the columns
* of the relation. Precondition: relationId identifies a TABLE
*/
static List *
CreateSecurityLabelCommands(Oid relationId)
{
List *securityLabelCommands = NIL;

if (!RegularTable(relationId)) /* should be an Assert ? */
{
return securityLabelCommands;
}

Relation pg_seclabel = table_open(SecLabelRelationId, AccessShareLock);
ScanKeyData skey[1];
ScanKeyInit(&skey[0], Anum_pg_seclabel_objoid, BTEqualStrategyNumber, F_OIDEQ,
ObjectIdGetDatum(relationId));
SysScanDesc scan = systable_beginscan(pg_seclabel, SecLabelObjectIndexId,
true, NULL, 1, &skey[0]);
HeapTuple tuple = NULL;
List *table_name = NIL;
Relation relation = NULL;
TupleDesc tupleDescriptor = NULL;
List *securityLabelStmts = NULL;

while (HeapTupleIsValid(tuple = systable_getnext(scan)))
{
SecLabelStmt *secLabelStmt = makeNode(SecLabelStmt);

if (relation == NULL)
{
relation = relation_open(relationId, AccessShareLock);
if (!RelationIsVisible(relationId))
{
char *nsname = get_namespace_name(RelationGetNamespace(relation));
table_name = lappend(table_name, makeString(nsname));
}
char *relname = get_rel_name(relationId);
table_name = lappend(table_name, makeString(relname));
}

Datum datumArray[Natts_pg_seclabel];
bool isNullArray[Natts_pg_seclabel];
int subObjectId = -1;

heap_deform_tuple(tuple, RelationGetDescr(pg_seclabel), datumArray,
isNullArray);
subObjectId = DatumGetInt32(
datumArray[Anum_pg_seclabel_objsubid - 1]);
secLabelStmt->provider = TextDatumGetCString(
datumArray[Anum_pg_seclabel_provider - 1]);
secLabelStmt->label = TextDatumGetCString(
datumArray[Anum_pg_seclabel_label - 1]);

if (subObjectId > 0)
{
/* Its a column; construct the name */
secLabelStmt->objtype = OBJECT_COLUMN;
List *col_name = list_copy(table_name);

if (tupleDescriptor == NULL)
{
tupleDescriptor = RelationGetDescr(relation);
}

Form_pg_attribute attrForm = TupleDescAttr(tupleDescriptor, subObjectId - 1);
char *attributeName = NameStr(attrForm->attname);
col_name = lappend(col_name, makeString(attributeName));

secLabelStmt->object = (Node *) col_name;
}
else
{
Assert(subObjectId == 0);
secLabelStmt->objtype = OBJECT_TABLE;
secLabelStmt->object = (Node *) table_name;
}

securityLabelStmts = lappend(securityLabelStmts, secLabelStmt);
}

Node *stmt = NULL;
foreach_ptr(stmt, securityLabelStmts)
{
char *secLabelStmtString = DeparseTreeNode(stmt);
TableDDLCommand *secLabelCommand = makeTableDDLCommandString(secLabelStmtString);
securityLabelCommands = lappend(securityLabelCommands, secLabelCommand);
}

systable_endscan(scan);
table_close(pg_seclabel, AccessShareLock);

if (relation != NULL)
{
relation_close(relation, AccessShareLock);
}

return securityLabelCommands;
}


/*
* IndexImpliedByAConstraint is a helper function to be used while scanning
* pg_index. It returns true if the index identified by the given indexForm is
Expand Down
1 change: 1 addition & 0 deletions src/include/distributed/metadata/distobject.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
extern bool ObjectExists(const ObjectAddress *address);
extern bool CitusExtensionObject(const ObjectAddress *objectAddress);
extern bool IsAnyObjectDistributed(const List *addresses);
extern bool IsAnyObjectDistributedIgnoreObjectSubId(const List *addresses);
extern bool ClusterHasDistributedFunctionWithDistArgument(void);
extern void MarkObjectDistributed(const ObjectAddress *distAddress);
extern void MarkObjectDistributedWithName(const ObjectAddress *distAddress, char *name,
Expand Down
Loading

0 comments on commit e478669

Please sign in to comment.