Chapter 3: Tensor Operations¶
⚙️ Learning Objectives
- Perform mathematical operations on tensors
- Understand matrix operations
- Reshape and manipulate tensor dimensions
- Use broadcasting effectively
Tensor Operations Overview
Tensor operations are the building blocks of all PyTorch computations. Understanding element-wise operations, matrix operations, and reshaping is crucial for building neural networks.
Mathematical Operations¶
Element-wise Operations¶
import torch
x = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
y = torch.tensor([[5, 6], [7, 8]], dtype=torch.float32)
# Addition
print(f"x + y:\n{x + y}")
print(f"torch.add(x, y):\n{torch.add(x, y)}")
# Subtraction
print(f"x - y:\n{x - y}")
print(f"torch.sub(x, y):\n{torch.sub(x, y)}")
# Multiplication
print(f"x * y:\n{x * y}")
print(f"torch.mul(x, y):\n{torch.mul(x, y)}")
# Division
print(f"x / y:\n{x / y}")
print(f"torch.div(x, y):\n{torch.div(x, y)}")
# Power
print(f"x ** 2:\n{x ** 2}")
print(f"torch.pow(x, 2):\n{torch.pow(x, 2)}")
# Square root
print(f"sqrt(x):\n{torch.sqrt(x)}")
# Exponential
print(f"exp(x):\n{torch.exp(x)}")
# Logarithm
print(f"log(x):\n{torch.log(x)}")
In-place Operations¶
import torch
x = torch.tensor([1.0, 2.0, 3.0])
print(f"Original x: {x}")
# In-place operations (end with _)
x.add_(10) # x = x + 10
print(f"After add_(10): {x}")
x.mul_(2) # x = x * 2
print(f"After mul_(2): {x}")
x.sqrt_() # x = sqrt(x)
print(f"After sqrt_(): {x}")
# Warning: In-place operations save memory but modify original tensor
In-place Operations
Use in-place operations (ending with _) to save memory, but be careful - they modify the original tensor and can break gradient computation. Avoid them when requires_grad=True.
Gradient Tracking
In-place operations on tensors with requires_grad=True can cause issues during backpropagation. Use regular operations instead when gradients are needed.
Comparison Operations¶
import torch
x = torch.tensor([1, 2, 3, 4, 5])
y = torch.tensor([5, 4, 3, 2, 1])
# Element-wise comparison
print(f"x == y: {x == y}")
print(f"x > y: {x > y}")
print(f"x < y: {x < y}")
print(f"x >= y: {x >= y}")
# Logical operations
print(f"x > 2 and x < 5: {(x > 2) & (x < 5)}")
print(f"x < 2 or x > 4: {(x < 2) | (x > 4)}")
# Finding elements
print(f"Max: {torch.max(x)}")
print(f"Min: {torch.min(x)}")
print(f"Argmax: {torch.argmax(x)}")
print(f"Argmin: {torch.argmin(x)}")
Matrix Operations¶
Matrix Multiplication¶
import torch
# 2D matrix multiplication
A = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
B = torch.tensor([[5, 6], [7, 8]], dtype=torch.float32)
# Method 1: @ operator
C1 = A @ B
print(f"A @ B:\n{C1}")
# Method 2: torch.mm
C2 = torch.mm(A, B)
print(f"torch.mm(A, B):\n{C2}")
# Method 3: torch.matmul (works for batched matrices too)
C3 = torch.matmul(A, B)
print(f"torch.matmul(A, B):\n{C3}")
# Element-wise multiplication (NOT matrix multiplication)
element_wise = A * B
print(f"Element-wise (A * B):\n{element_wise}")
Batched Matrix Multiplication¶
import torch
# Batch of matrices: [batch_size, m, n]
A = torch.randn(10, 3, 4) # 10 matrices of size 3x4
B = torch.randn(10, 4, 5) # 10 matrices of size 4x5
# Batched matrix multiplication
C = torch.bmm(A, B) # Result: 10 matrices of size 3x5
print(f"Batch matmul shape: {C.shape}")
# Using matmul (more flexible)
C2 = torch.matmul(A, B)
print(f"torch.matmul shape: {C2.shape}")
Transpose and Permute¶
import torch
x = torch.randn(2, 3, 4)
print(f"Original shape: {x.shape}")
# Simple transpose (2D only)
y = torch.randn(3, 4)
y_T = y.T
print(f"Transpose shape: {y_T.shape}")
# Transpose specific dimensions
x_transposed = x.transpose(0, 1) # Swap dim 0 and 1
print(f"After transpose(0, 1): {x_transposed.shape}")
# Permute (rearrange all dimensions)
x_permuted = x.permute(2, 0, 1) # New order: (4, 2, 3)
print(f"After permute(2, 0, 1): {x_permuted.shape}")
# Common pattern: Convert (B, H, W, C) to (B, C, H, W)
image = torch.randn(32, 224, 224, 3) # Batch, Height, Width, Channels
image_pytorch = image.permute(0, 3, 1, 2) # Batch, Channels, Height, Width
print(f"Image format conversion: {image.shape} -> {image_pytorch.shape}")
Reshaping Operations¶
View and Reshape¶
import torch
x = torch.arange(12)
print(f"Original: {x}")
print(f"Shape: {x.shape}")
# Reshape to 3x4
x_3x4 = x.view(3, 4)
print(f"\nView (3, 4):\n{x_3x4}")
# Reshape to 2x6
x_2x6 = x.reshape(2, 6)
print(f"\nReshape (2, 6):\n{x_2x6}")
# Using -1 to infer dimension
x_auto = x.view(2, -1) # -1 is inferred as 6
print(f"\nView (2, -1):\n{x_auto}")
# Difference between view and reshape:
# - view(): requires contiguous memory, faster
# - reshape(): may copy data if not contiguous, safer
Squeeze and Unsqueeze¶
import torch
# Unsqueeze: Add dimension
x = torch.randn(3, 4)
print(f"Original shape: {x.shape}")
x_unsqueezed_0 = x.unsqueeze(0) # Add dimension at position 0
print(f"Unsqueeze(0): {x_unsqueezed_0.shape}") # [1, 3, 4]
x_unsqueezed_1 = x.unsqueeze(1) # Add dimension at position 1
print(f"Unsqueeze(1): {x_unsqueezed_1.shape}") # [3, 1, 4]
x_unsqueezed_minus1 = x.unsqueeze(-1) # Add at end
print(f"Unsqueeze(-1): {x_unsqueezed_minus1.shape}") # [3, 4, 1]
# Squeeze: Remove dimensions of size 1
y = torch.randn(1, 3, 1, 4)
print(f"\nOriginal shape: {y.shape}")
y_squeezed = y.squeeze() # Remove all dims of size 1
print(f"Squeeze(): {y_squeezed.shape}") # [3, 4]
y_squeezed_0 = y.squeeze(0) # Remove specific dim
print(f"Squeeze(0): {y_squeezed_0.shape}") # [3, 1, 4]
Flatten¶
import torch
# Flatten tensor
x = torch.randn(2, 3, 4)
print(f"Original shape: {x.shape}")
# Flatten all dimensions
x_flat = x.flatten()
print(f"Flatten(): {x_flat.shape}") # [24]
# Flatten from specific dimension
x_flat_from_1 = x.flatten(start_dim=1)
print(f"Flatten(start_dim=1): {x_flat_from_1.shape}") # [2, 12]
# Common in neural networks: flatten except batch dimension
batch_images = torch.randn(32, 3, 28, 28) # 32 images, 3 channels, 28x28
flattened = batch_images.flatten(start_dim=1)
print(f"Batch flatten: {batch_images.shape} -> {flattened.shape}")
Broadcasting¶
Broadcasting allows operations between tensors of different shapes.
Broadcasting Rules¶
- If tensors have different number of dimensions, pad the smaller one with 1s on the left
- Dimensions are compatible if they are equal or one of them is 1
- Tensors are broadcast to the larger size in each dimension
import torch
# Example 1: Scalar broadcasting
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
scalar = 10
result = x + scalar # Scalar broadcast to (2, 3)
print(f"Scalar broadcasting:\n{result}")
# Example 2: Vector broadcasting
x = torch.randn(3, 4)
y = torch.randn(4) # Will broadcast to (3, 4)
result = x + y
print(f"\nVector broadcast: {x.shape} + {y.shape} = {result.shape}")
# Example 3: Matrix broadcasting
x = torch.randn(3, 1)
y = torch.randn(1, 4)
result = x + y # Broadcasts to (3, 4)
print(f"Matrix broadcast: {x.shape} + {y.shape} = {result.shape}")
# Example 4: Batch operations
batch = torch.randn(32, 3, 28, 28) # Batch of images
mean = torch.randn(1, 3, 1, 1) # Channel-wise mean
normalized = batch - mean # Broadcasting
print(f"Batch broadcast: {batch.shape} - {mean.shape} = {normalized.shape}")
# Example 5: Explicit broadcast
x = torch.randn(3, 1, 5)
y = torch.randn(1, 4, 1)
# Manually check broadcasting
broadcasted_shape = torch.broadcast_shapes(x.shape, y.shape)
print(f"Broadcast shape: {x.shape} + {y.shape} -> {broadcasted_shape}")
Concatenation and Stacking¶
Concatenate¶
import torch
x = torch.tensor([[1, 2], [3, 4]])
y = torch.tensor([[5, 6], [7, 8]])
# Concatenate along dimension 0 (rows)
concat_0 = torch.cat([x, y], dim=0)
print(f"Concat dim=0:\n{concat_0}")
print(f"Shape: {concat_0.shape}") # [4, 2]
# Concatenate along dimension 1 (columns)
concat_1 = torch.cat([x, y], dim=1)
print(f"\nConcat dim=1:\n{concat_1}")
print(f"Shape: {concat_1.shape}") # [2, 4]
# Concatenate multiple tensors
z = torch.tensor([[9, 10], [11, 12]])
concat_3 = torch.cat([x, y, z], dim=0)
print(f"\nConcat 3 tensors:\n{concat_3}")
print(f"Shape: {concat_3.shape}") # [6, 2]
Stack¶
import torch
x = torch.tensor([1, 2, 3])
y = torch.tensor([4, 5, 6])
z = torch.tensor([7, 8, 9])
# Stack along new dimension 0
stacked_0 = torch.stack([x, y, z], dim=0)
print(f"Stack dim=0:\n{stacked_0}")
print(f"Shape: {stacked_0.shape}") # [3, 3]
# Stack along new dimension 1
stacked_1 = torch.stack([x, y, z], dim=1)
print(f"\nStack dim=1:\n{stacked_1}")
print(f"Shape: {stacked_1.shape}") # [3, 3]
# Difference between cat and stack:
# - cat: concatenates along existing dimension
# - stack: creates new dimension
Split and Chunk¶
import torch
x = torch.arange(12).reshape(4, 3)
print(f"Original:\n{x}")
# Split into specific sizes
splits = torch.split(x, 2, dim=0) # Split into chunks of size 2
print(f"\nSplit by size 2:")
for i, split in enumerate(splits):
print(f"Split {i}: shape {split.shape}\n{split}")
# Split into specific number of chunks
chunks = torch.chunk(x, 2, dim=0) # Split into 2 chunks
print(f"\nChunk into 2:")
for i, chunk in enumerate(chunks):
print(f"Chunk {i}: shape {chunk.shape}\n{chunk}")
# Split along columns
col_splits = torch.split(x, [1, 2], dim=1) # Split into sizes [1, 2]
print(f"\nColumn splits:")
for i, split in enumerate(col_splits):
print(f"Split {i}: shape {split.shape}\n{split}")
Reduction Operations¶
Sum, Mean, etc.¶
import torch
x = torch.tensor([[1, 2, 3],
[4, 5, 6]], dtype=torch.float32)
# Sum
print(f"Sum all: {x.sum()}")
print(f"Sum dim=0: {x.sum(dim=0)}") # Column sums
print(f"Sum dim=1: {x.sum(dim=1)}") # Row sums
# Mean
print(f"\nMean all: {x.mean()}")
print(f"Mean dim=0: {x.mean(dim=0)}")
print(f"Mean dim=1: {x.mean(dim=1)}")
# Max and Min
print(f"\nMax: {x.max()}")
print(f"Min: {x.min()}")
# Max with indices
max_val, max_idx = x.max(dim=0)
print(f"Max dim=0: values={max_val}, indices={max_idx}")
# Standard deviation
print(f"\nStd: {x.std()}")
# Variance
print(f"Var: {x.var()}")
# Product
print(f"Product all: {x.prod()}")
# Keepdim parameter
sum_keepdim = x.sum(dim=1, keepdim=True)
print(f"\nSum with keepdim: shape {sum_keepdim.shape}")
print(sum_keepdim)
Cumulative Operations¶
import torch
x = torch.tensor([1, 2, 3, 4, 5])
# Cumulative sum
cumsum = torch.cumsum(x, dim=0)
print(f"Cumsum: {cumsum}") # [1, 3, 6, 10, 15]
# Cumulative product
cumprod = torch.cumprod(x, dim=0)
print(f"Cumprod: {cumprod}") # [1, 2, 6, 24, 120]
# 2D example
x_2d = torch.tensor([[1, 2, 3],
[4, 5, 6]])
cumsum_0 = torch.cumsum(x_2d, dim=0)
print(f"\nCumsum dim=0:\n{cumsum_0}")
cumsum_1 = torch.cumsum(x_2d, dim=1)
print(f"Cumsum dim=1:\n{cumsum_1}")
Advanced Operations¶
Clipping and Clamping¶
import torch
x = torch.tensor([-2, -1, 0, 1, 2, 3, 4, 5])
# Clamp values between min and max
clamped = torch.clamp(x, min=0, max=3)
print(f"Clamped [0, 3]: {clamped}")
# Only min
clamped_min = torch.clamp(x, min=0)
print(f"Clamped min=0: {clamped_min}")
# Only max
clamped_max = torch.clamp(x, max=3)
print(f"Clamped max=3: {clamped_max}")
# In-place
x.clamp_(0, 3)
print(f"In-place clamped: {x}")
Masking and Where¶
import torch
x = torch.randn(3, 4)
print(f"Original:\n{x}")
# Create mask
mask = x > 0
print(f"\nMask (x > 0):\n{mask}")
# Apply mask
positive_vals = x[mask]
print(f"Positive values: {positive_vals}")
# Where: conditional selection
result = torch.where(x > 0, x, torch.zeros_like(x))
print(f"\nWhere (keep positive, else 0):\n{result}")
# Replace values
x_replaced = torch.where(x > 0, torch.tensor(1.0), torch.tensor(-1.0))
print(f"Replaced (+1 if >0, else -1):\n{x_replaced}")
Sorting and Ordering¶
import torch
x = torch.tensor([3, 1, 4, 1, 5, 9, 2, 6])
# Sort
sorted_vals, sorted_indices = torch.sort(x)
print(f"Sorted values: {sorted_vals}")
print(f"Sorted indices: {sorted_indices}")
# Descending order
sorted_desc, _ = torch.sort(x, descending=True)
print(f"Sorted descending: {sorted_desc}")
# Top-k values
top3_vals, top3_indices = torch.topk(x, 3)
print(f"Top 3 values: {top3_vals}")
print(f"Top 3 indices: {top3_indices}")
# Argsort
argsort = torch.argsort(x)
print(f"Argsort: {argsort}")
# 2D sorting
x_2d = torch.randint(0, 10, (3, 4))
print(f"\n2D tensor:\n{x_2d}")
sorted_2d, _ = torch.sort(x_2d, dim=1)
print(f"Sorted along dim=1:\n{sorted_2d}")
Practice Exercises¶
Exercise 1: Matrix Operations¶
import torch
# Create two random matrices
A = torch.randn(3, 4)
B = torch.randn(4, 5)
# Tasks:
# 1. Perform matrix multiplication
# 2. Compute element-wise square of A
# 3. Find mean of each row in A
# Solutions:
C = torch.mm(A, B)
A_squared = A ** 2
row_means = A.mean(dim=1)
Exercise 2: Reshaping¶
import torch
# Create tensor of shape (2, 3, 4)
x = torch.randn(2, 3, 4)
# Tasks:
# 1. Flatten to 1D
# 2. Reshape to (6, 4)
# 3. Add a batch dimension at the beginning
# 4. Permute to (4, 2, 3)
# Solutions:
flat = x.flatten()
reshaped = x.reshape(6, 4)
with_batch = x.unsqueeze(0)
permuted = x.permute(2, 0, 1)
Exercise 3: Broadcasting¶
import torch
# Normalize a batch of images
images = torch.randn(32, 3, 64, 64) # Batch, Channels, H, W
mean = torch.tensor([0.485, 0.456, 0.406]).reshape(1, 3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).reshape(1, 3, 1, 1)
# Normalize
normalized = (images - mean) / std
print(f"Normalized shape: {normalized.shape}")
Performance Tips¶
import torch
import time
# 1. Use in-place operations to save memory
x = torch.randn(1000, 1000)
x.add_(5) # Better than x = x + 5
# 2. Avoid unnecessary copies
y = x.view(-1) # No copy
z = x.reshape(-1) # May copy
# 3. Vectorize operations
# Bad: Loop
result = []
for i in range(len(x)):
result.append(x[i] * 2)
# Good: Vectorized
result = x * 2
# 4. Use appropriate dtype
x_float16 = torch.randn(1000, 1000, dtype=torch.float16) # Faster on GPU
Next Steps¶
Continue to Chapter 4: Autograd & Gradients to learn about: - Automatic differentiation - Computing gradients - Gradient descent - Building custom functions
Key Takeaways¶
- ✅ PyTorch supports extensive mathematical operations
- ✅ Broadcasting enables operations on different-shaped tensors
- ✅ Reshape operations: view, reshape, squeeze, unsqueeze, flatten
- ✅ Use in-place operations (_) for memory efficiency
- ✅ Reduction operations: sum, mean, max, min with keepdim option
- ✅ Matrix operations: mm, bmm, matmul for different scenarios
Recommended Reads¶
📚 Official Documentation
- PyTorch Operations - Complete operations reference
- Tensor Operations - Mathematical operations
- Matrix Operations - Linear algebra functions
- Reduction Operations - Sum, mean, max, min, etc.
📖 Essential Articles
- Understanding Broadcasting - Broadcasting rules and examples
- Tensor Reshaping Guide - View, reshape, and memory layout
- Efficient Tensor Operations - Performance optimization
- Matrix Multiplication Guide - When to use mm, bmm, matmul
🎓 Learning Resources
- NumPy Operations in PyTorch - NumPy to PyTorch mapping
- Linear Algebra with PyTorch - Advanced tensor operations
- Broadcasting Tutorial - Broadcasting examples
💡 Best Practices
- In-Place Operations - When to use in-place ops
- Memory Efficient Operations - Reducing memory usage
- Vectorization Best Practices - Optimizing operations
🔬 Research Papers
- Array Programming with NumPy - Understanding array operations
- BLAS and LAPACK - Linear algebra libraries