def split_link_prediction_dataset(
    X: Float[torch.Tensor, "n_pairs n_dims"],
    y: Int[torch.Tensor, "n_pairs"],
    test_size: float = 0.2,
    downsample: int | None = None,
    random_state: int | None = None,
    **kwargs: Any,
) -> tuple[
    Float[torch.Tensor, "... n_dims"],
    Float[torch.Tensor, "... n_dims"],
    Int[torch.Tensor, "..."],
    Int[torch.Tensor, "..."],
    Int[torch.Tensor, "..."],
    Int[torch.Tensor, "..."],
]:
    """Split a link prediction dataset into train and test sets by node.

    The pair dimension is interpreted as a flattened, row-major
    ``(n_nodes, n_nodes)`` grid, i.e. row ``i * n_nodes + j`` describes the
    pair ``(i, j)``. Nodes (not pairs) are partitioned with
    ``train_test_split``; each output set contains only the pairs whose two
    endpoints both fall in the same partition, so pairs that straddle the
    train/test boundary are dropped.

    Args:
        X: Node-pair embeddings of shape (n_nodes^2, n_dims).
        y: Edge labels of shape (n_nodes^2,).
        test_size: Forwarded to ``train_test_split``; applied to the node
            indices, not to the pairs.
        downsample: If provided, keep at most this many randomly chosen
            pairs with label 1 and at most this many with label 0. Pairs
            that are not kept are NOT removed: their rows in copies of ``X``
            and ``y`` are set to zero, so they still appear in the outputs
            as all-zero embeddings with label 0. The inputs themselves are
            not modified.
        random_state: Random seed for reproducibility. When given, it seeds
            the *global* torch RNG via ``torch.manual_seed`` (a side effect
            visible to the caller) and is also passed to
            ``train_test_split``.
        **kwargs: Additional arguments for train_test_split.

    Returns:
        Tuple of (X_train, X_test, y_train, y_test, idx_train, idx_test).
        The X/y entries are flattened subgraphs of shape
        (n_train_nodes^2, n_dims) / (n_train_nodes^2,) and likewise for
        test. ``idx_train`` / ``idx_test`` are the node indices exactly as
        returned by ``train_test_split``.

    Raises:
        ValueError: If the number of pairs is not a perfect square, or if
            ``downsample`` is negative.
    """
    # Function-scope import so this block does not rely on a file-level import.
    import math

    n_pairs, n_dims = X.shape

    # Exact integer square root; a float sqrt can be off by one for large inputs.
    n_nodes = math.isqrt(n_pairs)
    # Explicit raise instead of `assert`, which is stripped under `python -O`.
    if n_nodes**2 != n_pairs:
        raise ValueError(f"Expected {n_nodes}^2 = {n_nodes**2} pairs, got {n_pairs}")
    # A negative value would make the `[:n]` slices below keep
    # "all but the last k" indices instead of "at most k".
    if downsample is not None and downsample < 0:
        raise ValueError(f"downsample must be non-negative, got {downsample}")

    if random_state is not None:
        torch.manual_seed(random_state)

    # Downsample if requested (before the split, so the square grid is preserved).
    if downsample is not None:
        pos_indices = torch.where(y == 1)[0]
        neg_indices = torch.where(y == 0)[0]

        # Sample up to `downsample` examples from each class. The positive
        # permutation is drawn first, then the negative one; keep this order
        # so results for a given seed stay reproducible.
        n_pos = min(len(pos_indices), downsample)
        n_neg = min(len(neg_indices), downsample)
        sampled_pos = pos_indices[torch.randperm(len(pos_indices))[:n_pos]]
        sampled_neg = neg_indices[torch.randperm(len(neg_indices))[:n_neg]]

        # Allocate the mask on y's device: the sampled indices live there, and
        # indexing a CPU mask with accelerator-resident indices raises.
        mask = torch.zeros(n_pairs, dtype=torch.bool, device=y.device)
        mask[sampled_pos] = True
        mask[sampled_neg] = True

        # Zero out unselected pairs on copies so the caller's tensors are untouched.
        X_filtered = X.clone()
        y_filtered = y.clone()
        X_filtered[~mask] = 0
        y_filtered[~mask] = 0
    else:
        X_filtered = X
        y_filtered = y

    # Reshape to adjacency format. `reshape` (unlike `view`) also accepts
    # non-contiguous inputs and is a no-copy view when the layout allows it.
    X_adj = X_filtered.reshape(n_nodes, n_nodes, n_dims)
    y_adj = y_filtered.reshape(n_nodes, n_nodes)

    # Split nodes into train/test.
    node_indices = torch.arange(n_nodes)
    idx_train, idx_test = train_test_split(
        node_indices, test_size=test_size, random_state=random_state, **kwargs
    )

    # Extract the train and test subgraphs (rows, then columns) and flatten.
    X_train = X_adj[idx_train][:, idx_train].reshape(-1, n_dims)
    y_train = y_adj[idx_train][:, idx_train].reshape(-1)
    X_test = X_adj[idx_test][:, idx_test].reshape(-1, n_dims)
    y_test = y_adj[idx_test][:, idx_test].reshape(-1)

    return X_train, X_test, y_train, y_test, idx_train, idx_test