"Attention Is All You Need", a paper first published on 12 June 2017 and now in its seventh revision (v7), introduced a concept that has changed AI models, and Transformer models in particular, more than any other.
One must pay attention when trying to learn something; the same is true in the world of AI!
In 2017, the paper "Attention Is All You Need" was released, and it changed the world of AI! Overnight, we went from large, bulky LSTM networks for sequence-to-sequence tasks to the Transformer model, which outperformed LSTMs by a wide margin.
Credit to the YouTube channel Concepts Illuminated for the following videos; in my opinion, they are by far the best videos explaining attention:
Learning to weight a particular item more or less heavily gives you a form of attention. There are several types of attention today; here is a basic class implementing some of them:
namespace AI
{
#region Using Statements:
using System;
using System.Threading;
using System.Threading.Tasks;
#endregion
/// <summary>
/// Multi-head scaled dot-product attention with learned query, key, value and output projections.
/// Self-attention is obtained by passing the same matrix as q, k and v; cross-attention by passing
/// keys/values taken from a different sequence than the queries.
/// </summary>
/// <remarks>
/// Inputs are matrices whose rows are tokens, with several sequences stacked on top of each other
/// (rows [b * seqLen, (b + 1) * seqLen) belong to sequence b). Gradients produced by
/// <see cref="Backward"/> are cached on the instance and consumed by <see cref="UpdateParametersWithAdam"/>,
/// <see cref="GetGradientNorm"/> and <see cref="ScaleGradients"/>.
/// </remarks>
[Serializable]
public class Attention
{
    #region Fields:
    /// <summary>
    /// Dimensionality of the model.
    /// </summary>
    private readonly int _dModel;
    /// <summary>
    /// Number of attention heads.
    /// </summary>
    private readonly int _numHeads;
    /// <summary>
    /// Dimensionality per attention head (dModel / numHeads).
    /// </summary>
    private readonly int _dHead;
    /// <summary>
    /// Query weight matrix [dModel, dModel].
    /// </summary>
    private Matrix<double> _Wq;
    /// <summary>
    /// Key weight matrix [dModel, dModel].
    /// </summary>
    private Matrix<double> _Wk;
    /// <summary>
    /// Value weight matrix [dModel, dModel].
    /// </summary>
    private Matrix<double> _Wv;
    /// <summary>
    /// Output weight matrix [dModel, dModel].
    /// </summary>
    private Matrix<double> _Wo;
    /// <summary>
    /// Gradient of the query weight matrix; null until <see cref="Backward"/> has run.
    /// </summary>
    private Matrix<double> _dWq;
    /// <summary>
    /// Gradient of the key weight matrix; null until <see cref="Backward"/> has run.
    /// </summary>
    private Matrix<double> _dWk;
    /// <summary>
    /// Gradient of the value weight matrix; null until <see cref="Backward"/> has run.
    /// </summary>
    private Matrix<double> _dWv;
    /// <summary>
    /// Gradient of the output weight matrix; null until <see cref="Backward"/> has run.
    /// </summary>
    private Matrix<double> _dWo;
    /// <summary>
    /// Maximum sequence length, used to infer the batch split when no mask is supplied.
    /// </summary>
    private readonly int _maxSeqLen;
    // Adam optimizer state variables
    /// <summary>
    /// First moment estimate for _Wq.
    /// </summary>
    private Matrix<double> _mWq;
    /// <summary>
    /// Second moment estimate for _Wq.
    /// </summary>
    private Matrix<double> _vWq;
    /// <summary>
    /// First moment estimate for _Wk.
    /// </summary>
    private Matrix<double> _mWk;
    /// <summary>
    /// Second moment estimate for _Wk.
    /// </summary>
    private Matrix<double> _vWk;
    /// <summary>
    /// First moment estimate for _Wv.
    /// </summary>
    private Matrix<double> _mWv;
    /// <summary>
    /// Second moment estimate for _Wv.
    /// </summary>
    private Matrix<double> _vWv;
    /// <summary>
    /// First moment estimate for _Wo.
    /// </summary>
    private Matrix<double> _mWo;
    /// <summary>
    /// Second moment estimate for _Wo.
    /// </summary>
    private Matrix<double> _vWo;
    #endregion

    /// <summary>
    /// Initializes multi-head attention with Xavier-initialized projections and zeroed Adam state.
    /// </summary>
    /// <param name="dModel">Dimensionality of the model; must be positive.</param>
    /// <param name="numHeads">Number of attention heads; must be positive and divide dModel.</param>
    /// <param name="maxSeqLen">Maximum sequence length (used when Forward/Backward are called without a mask).</param>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if dModel or numHeads is not positive.</exception>
    /// <exception cref="ArgumentException">Thrown if dModel is not divisible by numHeads.</exception>
    public Attention(int dModel, int numHeads, int maxSeqLen)
    {
        if (dModel <= 0)
            throw new ArgumentOutOfRangeException(nameof(dModel), "dModel must be positive.");
        // Guard before the modulo below, which would otherwise throw DivideByZeroException.
        if (numHeads <= 0)
            throw new ArgumentOutOfRangeException(nameof(numHeads), "numHeads must be positive.");
        if (dModel % numHeads != 0)
            throw new ArgumentException("dModel must be divisible by numHeads.");
        _dModel = dModel;
        _numHeads = numHeads;
        _dHead = dModel / numHeads;
        _maxSeqLen = maxSeqLen;
        _Wq = Matrix<double>.InitializeXavier(_dModel, _dModel);
        _Wk = Matrix<double>.InitializeXavier(_dModel, _dModel);
        _Wv = Matrix<double>.InitializeXavier(_dModel, _dModel);
        _Wo = Matrix<double>.InitializeXavier(_dModel, _dModel);
        // Initialize Adam state
        _mWq = new Matrix<double>(_Wq.Rows, _Wq.Columns);
        _vWq = new Matrix<double>(_Wq.Rows, _Wq.Columns);
        _mWk = new Matrix<double>(_Wk.Rows, _Wk.Columns);
        _vWk = new Matrix<double>(_Wk.Rows, _Wk.Columns);
        _mWv = new Matrix<double>(_Wv.Rows, _Wv.Columns);
        _vWv = new Matrix<double>(_Wv.Rows, _Wv.Columns);
        _mWo = new Matrix<double>(_Wo.Rows, _Wo.Columns);
        _vWo = new Matrix<double>(_Wo.Rows, _Wo.Columns);
    }

    /// <summary>
    /// Forward pass through multi-head attention.
    /// </summary>
    /// <param name="q">Query matrix [batchSize * qSeqLen, dModel].</param>
    /// <param name="k">Key matrix [batchSize * kSeqLen, dModel].</param>
    /// <param name="v">Value matrix [batchSize * kSeqLen, dModel].</param>
    /// <param name="logger">Optional sink for warnings and errors; may be null.</param>
    /// <param name="mask">Optional additive mask [qSeqLen, kSeqLen] (e.g. a causal mask) added to every
    /// sequence's scaled scores before the softmax. When supplied, its shape defines qSeqLen and kSeqLen.</param>
    /// <returns>Attention output [batchSize * qSeqLen, dModel]. On invalid shapes or a failed head a zero
    /// matrix [q.Rows, dModel] is returned and the reason is reported through <paramref name="logger"/>.</returns>
    /// <exception cref="ArgumentNullException">Thrown if <paramref name="q"/> is null.</exception>
    /// <remarks>
    /// Without a mask the batch size is inferred as max(1, min(q.Rows, k.Rows) / maxSeqLen); if that split
    /// is inconsistent, the whole input is treated as one sequence.
    /// </remarks>
    public Matrix<double> Forward(Matrix<double> q, Matrix<double> k, Matrix<double> v, Action<string> logger, Matrix<double> mask = null)
    {
        // Input validation
        if (q == null || k == null || v == null || q.Columns != _dModel || k.Columns != _dModel || v.Columns != _dModel)
        {
            logger?.Invoke($"Error: Invalid inputs (q={q?.Rows}x{q?.Columns}, k={k?.Rows}x{k?.Columns}, v={v?.Rows}x{v?.Columns}, expected columns={_dModel}). Returning zero matrix.");
            // The fallback zero matrix needs q.Rows rows, so without q there is nothing valid to return.
            if (q == null)
                throw new ArgumentNullException(nameof(q), "Query matrix cannot be null.");
            return new Matrix<double>(q.Rows, _dModel);
        }
        if (q.Rows == 0 || k.Rows == 0 || v.Rows == 0)
        {
            logger?.Invoke($"Warning: Empty inputs (q.Rows={q.Rows}, k.Rows={k.Rows}, v.Rows={v.Rows}). Returning zero matrix.");
            return new Matrix<double>(q.Rows, _dModel);
        }
        if (!TryResolveLayout(q.Rows, k.Rows, v.Rows, mask, logger, out int batchSize, out int qSeqLen, out int kSeqLen))
            return new Matrix<double>(q.Rows, _dModel);

        // Compute Q, K, V
        Matrix<double> Q = q * _Wq;
        Matrix<double> K = k * _Wk;
        Matrix<double> V = v * _Wv;
        if (Q.Rows != q.Rows || K.Rows != k.Rows || V.Rows != v.Rows || Q.Columns != _dModel || K.Columns != _dModel || V.Columns != _dModel)
        {
            logger?.Invoke($"Error: QKV computation failed (Q=[{Q.Rows},{Q.Columns}], K=[{K.Rows},{K.Columns}], V=[{V.Rows},{V.Columns}]). Returning zero matrix.");
            return new Matrix<double>(q.Rows, _dModel);
        }

        double scale = 1.0 / Math.Sqrt(_dHead);
        var headOutputs = new Matrix<double>[_numHeads];
        Parallel.For(0, _numHeads, h =>
        {
            try
            {
                var qHead = Q.SubMatrix(0, Q.Rows, h * _dHead, _dHead);
                var kHead = K.SubMatrix(0, K.Rows, h * _dHead, _dHead);
                var vHead = V.SubMatrix(0, V.Rows, h * _dHead, _dHead);
                var output = new Matrix<double>(q.Rows, _dHead);
                bool batchFailed = false;
                Parallel.For(0, batchSize, b =>
                {
                    try
                    {
                        var qBatch = qHead.SubMatrix(b * qSeqLen, qSeqLen, 0, _dHead);
                        var kBatch = kHead.SubMatrix(b * kSeqLen, kSeqLen, 0, _dHead);
                        var scoresBatch = (qBatch * kBatch.Transpose()) * scale;
                        if (mask != null)
                        {
                            if (scoresBatch.Rows != mask.Rows || scoresBatch.Columns != mask.Columns)
                            {
                                logger?.Invoke($"Error: Mask size mismatch (scores=[{scoresBatch.Rows},{scoresBatch.Columns}], mask=[{mask.Rows},{mask.Columns}]) in head {h}, batch {b}.");
                                batchFailed = true;
                                return;
                            }
                            scoresBatch = scoresBatch + mask;
                        }
                        var weights = Softmax(scoresBatch);
                        var vBatch = vHead.SubMatrix(b * kSeqLen, kSeqLen, 0, _dHead);
                        var outputBatch = weights * vBatch;
                        if (outputBatch.Rows != qSeqLen || outputBatch.Columns != _dHead)
                        {
                            logger?.Invoke($"Error: Output batch size incorrect (outputBatch=[{outputBatch.Rows},{outputBatch.Columns}], expected=[{qSeqLen},{_dHead}]) in head {h}, batch {b}.");
                            batchFailed = true;
                            return;
                        }
                        WriteBlock(outputBatch, output, b * qSeqLen, 0);
                    }
                    catch (Exception ex)
                    {
                        logger?.Invoke($"Error in inner Parallel.For (head={h}, batch={b}): {ex.Message}");
                        batchFailed = true;
                    }
                });
                // A failed batch leaves its rows as zeros; treat the whole head as failed instead of
                // returning a silently partial result.
                headOutputs[h] = batchFailed ? null : output;
            }
            catch (Exception ex)
            {
                logger?.Invoke($"Error in outer Parallel.For (head={h}): {ex.Message}");
                headOutputs[h] = null;
            }
        });
        if (Array.Exists(headOutputs, head => head == null))
        {
            logger?.Invoke("Error: One or more attention heads failed to compute. Returning zero matrix.");
            return new Matrix<double>(q.Rows, _dModel);
        }
        var combined = ConcatenateHeads(headOutputs);
        var result = combined * _Wo;
        if (result.Rows != q.Rows || result.Columns != _dModel)
        {
            logger?.Invoke($"Error: Final output shape incorrect (result=[{result.Rows},{result.Columns}], expected=[{q.Rows},{_dModel}]). Returning zero matrix.");
            return new Matrix<double>(q.Rows, _dModel);
        }
        if (ContainsNonFinite(result))
        {
            logger?.Invoke("Warning: Output contains NaN or Infinity values.");
        }
        return result;
    }

    /// <summary>
    /// Backward pass through multi-head attention. Recomputes the forward activations, caches the
    /// weight gradients (_dWq, _dWk, _dWv, _dWo) and returns the gradients w.r.t. the inputs.
    /// </summary>
    /// <param name="dOutput">Gradient w.r.t. output [batchSize * qSeqLen, dModel].</param>
    /// <param name="q">Query matrix [batchSize * qSeqLen, dModel].</param>
    /// <param name="k">Key matrix [batchSize * kSeqLen, dModel].</param>
    /// <param name="v">Value matrix [batchSize * kSeqLen, dModel].</param>
    /// <param name="mask">Optional mask [qSeqLen, kSeqLen]; must be the same mask passed to Forward.</param>
    /// <returns>Tuple of gradients w.r.t. query, key, and value matrices.</returns>
    /// <exception cref="ArgumentNullException">Thrown if dOutput, q, k or v is null.</exception>
    /// <exception cref="ArgumentException">Thrown if the shapes are inconsistent or cannot be split into sequences
    /// the way Forward splits them.</exception>
    public (Matrix<double> dQ, Matrix<double> dK, Matrix<double> dV) Backward(
        Matrix<double> dOutput, Matrix<double> q, Matrix<double> k, Matrix<double> v, Matrix<double> mask = null)
    {
        if (dOutput == null) throw new ArgumentNullException(nameof(dOutput));
        if (q == null) throw new ArgumentNullException(nameof(q));
        if (k == null) throw new ArgumentNullException(nameof(k));
        if (v == null) throw new ArgumentNullException(nameof(v));
        if (q.Columns != _dModel || k.Columns != _dModel || v.Columns != _dModel ||
            dOutput.Rows != q.Rows || dOutput.Columns != _dModel)
        {
            throw new ArgumentException(
                $"Invalid shapes (dOutput={dOutput.Rows}x{dOutput.Columns}, q={q.Rows}x{q.Columns}, k={k.Rows}x{k.Columns}, v={v.Rows}x{v.Columns}, dModel={_dModel}).");
        }
        // Split the stacked rows exactly as Forward does so the recomputed activations line up.
        if (!TryResolveLayout(q.Rows, k.Rows, v.Rows, mask, null, out int batchSize, out int qSeqLen, out int kSeqLen))
        {
            string maskShape = mask == null ? "none" : $"{mask.Rows}x{mask.Columns}";
            throw new ArgumentException(
                $"Cannot split inputs into sequences (q.Rows={q.Rows}, k.Rows={k.Rows}, v.Rows={v.Rows}, mask={maskShape}, maxSeqLen={_maxSeqLen}).");
        }

        double scale = 1.0 / Math.Sqrt(_dHead);
        // Gradient w.r.t. the concatenated head outputs (the input of the output projection).
        var dConcat = dOutput * _Wo.Transpose();

        // Recompute the forward activations needed by the chain rule.
        var Q = q * _Wq;
        var K = k * _Wk;
        var V = v * _Wv;
        var combined = new Matrix<double>(q.Rows, _dModel);

        var dQProj = new Matrix<double>(q.Rows, _dModel);
        var dKProj = new Matrix<double>(k.Rows, _dModel);
        var dVProj = new Matrix<double>(v.Rows, _dModel);

        Parallel.For(0, _numHeads, h =>
        {
            int col = h * _dHead;
            var dHead = dConcat.SubMatrix(0, dConcat.Rows, col, _dHead);
            var qHead = Q.SubMatrix(0, Q.Rows, col, _dHead);
            var kHead = K.SubMatrix(0, K.Rows, col, _dHead);
            var vHead = V.SubMatrix(0, V.Rows, col, _dHead);
            for (int b = 0; b < batchSize; b++)
            {
                int startQ = b * qSeqLen;
                int startK = b * kSeqLen;
                var dHeadBatch = dHead.SubMatrix(startQ, qSeqLen, 0, _dHead);
                var qHeadBatch = qHead.SubMatrix(startQ, qSeqLen, 0, _dHead);
                var kHeadBatch = kHead.SubMatrix(startK, kSeqLen, 0, _dHead);
                var vHeadBatch = vHead.SubMatrix(startK, kSeqLen, 0, _dHead);

                var scoresBatch = (qHeadBatch * kHeadBatch.Transpose()) * scale;
                if (mask != null)
                    scoresBatch = scoresBatch + mask;
                var attnBatch = Softmax(scoresBatch);

                // out = A * V  =>  dA = dOut * V^T,  dV = A^T * dOut
                var dAttnBatch = dHeadBatch * vHeadBatch.Transpose();
                var dVHeadBatch = attnBatch.Transpose() * dHeadBatch;
                // S = Q * K^T * scale  =>  dQ = dS * K * scale,  dK = dS^T * Q * scale
                var dScoresBatch = SoftmaxBackward(dAttnBatch, attnBatch);
                var dQHeadBatch = (dScoresBatch * kHeadBatch) * scale;
                var dKHeadBatch = (dScoresBatch.Transpose() * qHeadBatch) * scale;

                // Every (head, batch) pair writes a disjoint block of rows/columns, so the
                // results do not depend on how the heads are scheduled.
                WriteBlock(attnBatch * vHeadBatch, combined, startQ, col);
                WriteBlock(dQHeadBatch, dQProj, startQ, col);
                WriteBlock(dKHeadBatch, dKProj, startK, col);
                WriteBlock(dVHeadBatch, dVProj, startK, col);
            }
        });

        // result = combined * Wo  =>  dWo = combined^T * dOutput.
        _dWo = combined.Transpose() * dOutput;
        _dWq = q.Transpose() * dQProj;
        _dWk = k.Transpose() * dKProj;
        _dWv = v.Transpose() * dVProj;

        var dQ = dQProj * _Wq.Transpose();
        var dK = dKProj * _Wk.Transpose();
        var dV = dVProj * _Wv.Transpose();
        return (dQ, dK, dV);
    }

    /// <summary>
    /// Updates attention weight matrices with Adam optimization, using the gradients cached by Backward.
    /// </summary>
    /// <param name="adam">The Adam optimizer instance.</param>
    /// <param name="t">The current timestep for bias correction.</param>
    /// <exception cref="ArgumentNullException">Thrown if adam is null.</exception>
    /// <exception cref="InvalidOperationException">Thrown if Backward has not produced gradients yet.</exception>
    public void UpdateParametersWithAdam(AdamOptimizer adam, int t)
    {
        if (adam == null)
            throw new ArgumentNullException(nameof(adam));
        EnsureGradientsComputed();
        // adam.Update returns (first moment, second moment, updated weights).
        (_mWq, _vWq, _Wq) = adam.Update(_Wq, _dWq, _mWq, _vWq, t);
        (_mWk, _vWk, _Wk) = adam.Update(_Wk, _dWk, _mWk, _vWk, t);
        (_mWv, _vWv, _Wv) = adam.Update(_Wv, _dWv, _mWv, _vWv, t);
        (_mWo, _vWo, _Wo) = adam.Update(_Wo, _dWo, _mWo, _vWo, t);
    }

    /// <summary>
    /// Computes the global L2 norm of all four weight gradients.
    /// </summary>
    /// <returns>sqrt of the sum of squared entries of dWq, dWk, dWv and dWo.</returns>
    /// <exception cref="InvalidOperationException">Thrown if Backward has not produced gradients yet.</exception>
    public double GetGradientNorm()
    {
        EnsureGradientsComputed();
        return Math.Sqrt(_dWq.ElementWiseMultiply(_dWq).Sum() +
                         _dWk.ElementWiseMultiply(_dWk).Sum() +
                         _dWv.ElementWiseMultiply(_dWv).Sum() +
                         _dWo.ElementWiseMultiply(_dWo).Sum());
    }

    /// <summary>
    /// Concatenates outputs from all attention heads column-wise.
    /// </summary>
    /// <param name="heads">Array of head outputs [numHeads][batchSize * qSeqLen, dHead].</param>
    /// <returns>Concatenated matrix [batchSize * qSeqLen, dModel].</returns>
    private Matrix<double> ConcatenateHeads(Matrix<double>[] heads)
    {
        var result = new Matrix<double>(heads[0].Rows, _dModel);
        Parallel.For(0, heads[0].Rows, i =>
        {
            for (int h = 0; h < _numHeads; h++)
                for (int j = 0; j < _dHead; j++)
                    result[i, h * _dHead + j] = heads[h][i, j];
        });
        return result;
    }

    /// <summary>
    /// Applies a numerically stable softmax to attention scores along each row.
    /// </summary>
    /// <param name="input">Attention scores [qSeqLen, kSeqLen].</param>
    /// <returns>Softmax probabilities [qSeqLen, kSeqLen].</returns>
    private Matrix<double> Softmax(Matrix<double> input)
    {
        var result = new Matrix<double>(input.Rows, input.Columns);
        Parallel.For(0, input.Rows, i =>
        {
            var row = input.Row(i);
            // Subtracting the row maximum keeps exp() from overflowing.
            double max = row.Max();
            var exp = row.AddScalar(-max).Exp();
            // The epsilon only guards against a zero denominator.
            double sum = exp.Sum() + 1e-10;
            for (int j = 0; j < input.Columns; j++)
                result[i, j] = exp[0, j] / sum;
        });
        return result;
    }

    /// <summary>
    /// Computes the gradient of the row-wise softmax operation.
    /// </summary>
    /// <param name="dOutput">Gradient w.r.t. softmax output [qSeqLen, kSeqLen].</param>
    /// <param name="probs">Softmax probabilities [qSeqLen, kSeqLen].</param>
    /// <returns>Gradient w.r.t. softmax input [qSeqLen, kSeqLen].</returns>
    private Matrix<double> SoftmaxBackward(Matrix<double> dOutput, Matrix<double> probs)
    {
        var dScores = new Matrix<double>(probs.Rows, probs.Columns);
        Parallel.For(0, probs.Rows, i =>
        {
            // Jacobian-vector product of softmax: dS_j = p_j * (dP_j - sum_k dP_k * p_k).
            // This is algebraically equal to sum_k dP_k * p_k * (delta_jk - p_j) but O(n) per row.
            double dot = 0;
            for (int k = 0; k < probs.Columns; k++)
                dot += dOutput[i, k] * probs[i, k];
            for (int j = 0; j < probs.Columns; j++)
                dScores[i, j] = probs[i, j] * (dOutput[i, j] - dot);
        });
        return dScores;
    }

    /// <summary>
    /// Scales all gradients in the multi-head attention mechanism, including query, key, value,
    /// and output weight gradients.
    /// </summary>
    /// <param name="scale">The scaling factor to apply to all gradients (e.g., for gradient clipping).</param>
    /// <exception cref="ArgumentException">Thrown if scale is NaN, infinite, or negative.</exception>
    /// <exception cref="InvalidOperationException">Thrown if Backward has not produced gradients yet.</exception>
    public void ScaleGradients(double scale)
    {
        if (double.IsNaN(scale) || double.IsInfinity(scale) || scale < 0)
            throw new ArgumentException("Scale must be a non-negative finite number.", nameof(scale));
        EnsureGradientsComputed();
        _dWq *= scale;
        _dWk *= scale;
        _dWv *= scale;
        _dWo *= scale;
    }

    /// <summary>
    /// Determines how the stacked rows of q, k and v split into batchSize sequences of qSeqLen / kSeqLen rows.
    /// Shared by Forward and Backward so both passes always agree on the layout.
    /// </summary>
    /// <returns>true when a consistent split exists; otherwise false after reporting the reason to logger.</returns>
    private bool TryResolveLayout(int qRows, int kRows, int vRows, Matrix<double> mask, Action<string> logger,
        out int batchSize, out int qSeqLen, out int kSeqLen)
    {
        batchSize = 0;
        qSeqLen = 0;
        kSeqLen = 0;
        if (mask != null)
        {
            if (mask.Rows <= 0 || mask.Columns <= 0)
            {
                logger?.Invoke($"Error: Mask invalid (Rows={mask.Rows}, Columns={mask.Columns}). Returning zero matrix.");
                return false;
            }
            // The mask shape fixes the per-sequence lengths; the batch size follows from q.
            qSeqLen = mask.Rows;
            kSeqLen = mask.Columns;
            batchSize = qRows / qSeqLen;
            if (qRows % qSeqLen != 0 || batchSize <= 0 || kRows != batchSize * kSeqLen || vRows != batchSize * kSeqLen)
            {
                logger?.Invoke($"Error: Dimension mismatch (q.Rows={qRows}, k.Rows={kRows}, v.Rows={vRows}, mask={mask.Rows}x{mask.Columns}). Returning zero matrix.");
                return false;
            }
            return true;
        }

        if (_maxSeqLen <= 0)
        {
            logger?.Invoke($"Error: _maxSeqLen ({_maxSeqLen}) invalid. Returning zero matrix.");
            return false;
        }
        batchSize = Math.Max(1, Math.Min(qRows, kRows) / _maxSeqLen);
        qSeqLen = qRows / batchSize;
        kSeqLen = kRows / batchSize;
        if (qSeqLen <= 0 || qSeqLen > _maxSeqLen || qRows % batchSize != 0 ||
            kSeqLen <= 0 || kSeqLen > _maxSeqLen || kRows % batchSize != 0 || vRows != kRows)
        {
            logger?.Invoke($"Warning: Inconsistent dimensions (q.Rows={qRows}, k.Rows={kRows}, v.Rows={vRows}, _maxSeqLen={_maxSeqLen}). Falling back to batchSize=1.");
            batchSize = 1;
            qSeqLen = qRows;
            kSeqLen = kRows;
            if (qSeqLen > _maxSeqLen || kSeqLen > _maxSeqLen || vRows != kRows)
            {
                logger?.Invoke($"Error: Sequence lengths exceed _maxSeqLen (qSeqLen={qSeqLen}, kSeqLen={kSeqLen}, _maxSeqLen={_maxSeqLen}). Returning zero matrix.");
                return false;
            }
        }
        return true;
    }

    /// <summary>
    /// Copies every element of source into target so that source[0, 0] lands at target[rowOffset, colOffset].
    /// </summary>
    private static void WriteBlock(Matrix<double> source, Matrix<double> target, int rowOffset, int colOffset)
    {
        for (int i = 0; i < source.Rows; i++)
            for (int j = 0; j < source.Columns; j++)
                target[rowOffset + i, colOffset + j] = source[i, j];
    }

    /// <summary>
    /// Returns true if any element of the matrix is NaN or infinite.
    /// </summary>
    private static bool ContainsNonFinite(Matrix<double> matrix)
    {
        for (int i = 0; i < matrix.Rows; i++)
        {
            for (int j = 0; j < matrix.Columns; j++)
            {
                double value = matrix[i, j];
                if (double.IsNaN(value) || double.IsInfinity(value))
                    return true;
            }
        }
        return false;
    }

    /// <summary>
    /// Throws if Backward has not yet populated all four weight gradients.
    /// </summary>
    private void EnsureGradientsComputed()
    {
        if (_dWq == null || _dWk == null || _dWv == null || _dWo == null)
            throw new InvalidOperationException("Gradients are not available; call Backward before using them.");
    }
}
}
And here is an example of how to use it:
using NeuralNetworks;
class Program
{
    /// <summary>
    /// Small demonstration of <c>AI.Attention</c>: cross-attention of one query token over two
    /// key/value tokens, followed by causal self-attention over the same two tokens.
    /// </summary>
    /// <remarks>
    /// The projection weights are randomly (Xavier) initialized, so the printed numbers change between runs.
    /// Names are fully qualified so the example does not depend on which using directives are present.
    /// </remarks>
    static void Main()
    {
        // dModel = 4 split across 2 heads (2 dimensions per head); each sequence holds at most 2 tokens.
        var attention = new AI.Attention(dModel: 4, numHeads: 2, maxSeqLen: 2);

        // Dummy data: imagine these are word vectors (one row per token).
        double[][] query = new double[][] { new double[] { 1, 0, 0, 1 } }; // One asking word
        double[][] keys = new double[][]
        {
            new double[] { 1, 0, 0, 1 }, // Word 1
            new double[] { 0, 1, 1, 0 }  // Word 2
        };
        double[][] values = keys; // Keys and values often come from the same tokens

        var q = new AI.Matrix<double>(query);
        var k = new AI.Matrix<double>(keys);
        var v = new AI.Matrix<double>(values);

        System.Console.WriteLine("Cross-Attention (1 query over 2 keys, 2 heads):");
        PrintMatrix(ToJagged(attention.Forward(q, k, v, System.Console.WriteLine)));

        // Additive causal mask: 0 keeps a score, a large negative value drives its softmax weight to ~0,
        // so token i can only attend to tokens 0..i.
        var causalMask = new AI.Matrix<double>(new double[][]
        {
            new double[] { 0, -1e9 },
            new double[] { 0, 0 }
        });
        System.Console.WriteLine("Causal Self-Attention (2 tokens, 2 heads):");
        PrintMatrix(ToJagged(attention.Forward(k, k, v, System.Console.WriteLine, causalMask)));
    }

    /// <summary>
    /// Copies a matrix into a jagged array so it can be printed with <see cref="PrintMatrix"/>.
    /// </summary>
    static double[][] ToJagged(AI.Matrix<double> matrix)
    {
        var rows = new double[matrix.Rows][];
        for (int i = 0; i < matrix.Rows; i++)
        {
            rows[i] = new double[matrix.Columns];
            for (int j = 0; j < matrix.Columns; j++)
                rows[i][j] = matrix[i, j];
        }
        return rows;
    }

    /// <summary>
    /// Prints each row of a jagged array as a bracketed, comma-separated list.
    /// </summary>
    static void PrintMatrix(double[][] matrix)
    {
        foreach (var row in matrix)
            System.Console.WriteLine($"[{string.Join(", ", row)}]");
    }
}
The Attention class above relies on the following Matrix class for its matrix and vector operations:
namespace AI
{
#region Using Statements:
using System;
using System.Linq;
using System.Threading.Tasks;
#endregion
/// <summary>
/// Represents a matrix of numerical values for efficient linear algebra operations, optimized for use in machine learning models like Transformers.
/// Supports generic numeric types (float or double) with parallelized operations for speed and thread safety.
/// </summary>
/// <typeparam name="T">The numeric type, constrained to float or double, implementing IComparable and IEquatable for comparisons and equality checks.</typeparam>
/// <remarks>
/// This matrix is a 2D array internally, designed for operations like addition, multiplication, transposition, and normalization.
/// It uses jagged arrays for memory efficiency and parallel processing for performance, with thread-safe random number generation where applicable.
/// </remarks>
public class Matrix<T> where T : struct, IComparable<T>, IEquatable<T>
{
#region Fields:
/// <summary>
/// Threshold for switching to parallel operations; tune based on testing for optimal performance.
/// </summary>
private const int ParallelThreshold = 100;
/// <summary>
/// Internal jagged array storing the matrix data, where each row is an array of type T.
/// </summary>
/// <remarks>
/// NOTE(review): this field is public, so external code can bypass the indexer's bounds checks and lock;
/// consider narrowing its visibility once no callers outside this class depend on it.
/// </remarks>
public T[][] _data;
/// <summary>
/// Number of rows in the matrix.
/// </summary>
private readonly int _rows;
/// <summary>
/// Number of columns in the matrix.
/// </summary>
private readonly int _cols;
/// <summary>
/// Indicates whether T is float, used for type-specific operations to optimize performance.
/// </summary>
private static readonly bool IsFloat = typeof(T) == typeof(float);
/// <summary>
/// Indicates whether T is double, used for type-specific operations to optimize precision.
/// </summary>
private static readonly bool IsDouble = typeof(T) == typeof(double);
/// <summary>
/// Gets the number of rows in the matrix.
/// </summary>
public int Rows => _rows;
/// <summary>
/// Gets the number of columns in the matrix.
/// </summary>
public int Columns => _cols;
/// <summary>
/// Per-instance lock taken by the indexer setter and SetRow when writing elements.
/// </summary>
private readonly object _syncLock = new object();
/// <summary>
/// Per-thread random number generator used by the weight initializers. Each thread gets its own
/// <see cref="Random"/> seeded from a fresh GUID hash, so parallel initialization never shares a Random instance.
/// </summary>
/// <remarks>
/// Fully qualified because System.Threading is not among this file's using directives.
/// </remarks>
private static readonly System.Threading.ThreadLocal<Random> _random =
    new System.Threading.ThreadLocal<Random>(() => new Random(Guid.NewGuid().GetHashCode()));
#endregion
#region Properties:
/// <summary>
/// Element accessor addressed by zero-based (row, column) position.
/// </summary>
/// <param name="row">Zero-based row index in [0, Rows).</param>
/// <param name="col">Zero-based column index in [0, Columns).</param>
/// <returns>The element stored at (<paramref name="row"/>, <paramref name="col"/>).</returns>
/// <exception cref="IndexOutOfRangeException">Thrown when either index falls outside the matrix.</exception>
/// <remarks>
/// Reads go straight to the backing array; writes are serialized through the instance lock.
/// </remarks>
public T this[int row, int col]
{
    get
    {
        // Casting to uint folds the "negative" and "too large" checks into one comparison each.
        bool inside = (uint)row < (uint)_rows && (uint)col < (uint)_cols;
        if (!inside)
            throw new IndexOutOfRangeException($"Index [{row}, {col}] is out of bounds for matrix [{_rows}, {_cols}].");
        return _data[row][col];
    }
    set
    {
        bool inside = (uint)row < (uint)_rows && (uint)col < (uint)_cols;
        if (!inside)
            throw new IndexOutOfRangeException($"Index [{row}, {col}] is out of bounds for matrix [{_rows}, {_cols}].");
        lock (_syncLock)
        {
            _data[row][col] = value;
        }
    }
}
#endregion
/// <summary>
/// Creates a zero-filled matrix of the requested size.
/// </summary>
/// <param name="rows">Row count; must be greater than zero.</param>
/// <param name="cols">Column count; must be greater than zero.</param>
/// <exception cref="ArgumentException">Thrown when either dimension is zero or negative.</exception>
/// <remarks>
/// Row arrays are allocated in parallel; each worker writes only its own slot, so no locking is needed.
/// </remarks>
public Matrix(int rows, int cols)
{
    if (!(rows > 0 && cols > 0))
        throw new ArgumentException($"Matrix dimensions must be positive, got rows={rows}, cols={cols}.");
    _rows = rows;
    _cols = cols;
    var storage = new T[rows][];
    Parallel.For(0, rows, r =>
    {
        storage[r] = new T[cols];
    });
    _data = storage;
}
/// <summary>
/// Initializes a matrix from a provided 2D jagged array.
/// </summary>
/// <param name="data">Jagged array of values; must be non-null, non-empty, contain no null rows, and be rectangular.</param>
/// <exception cref="ArgumentNullException">Thrown if data is null.</exception>
/// <exception cref="ArgumentException">Thrown if data is empty, contains a null row, or has inconsistent row lengths.</exception>
/// <remarks>
/// Each row is copied (in parallel), so later changes to <paramref name="data"/> do not affect the matrix.
/// </remarks>
public Matrix(T[][] data)
{
    if (data == null)
        throw new ArgumentNullException(nameof(data), "Input data cannot be null.");
    if (data.Length == 0)
        throw new ArgumentException("Input data must have at least one row and column.");
    // Reject null rows up front; otherwise the length checks below fail with a NullReferenceException.
    if (data.Any(row => row == null))
        throw new ArgumentException("Input data must not contain null rows.", nameof(data));
    if (data[0].Length == 0)
        throw new ArgumentException("Input data must have at least one row and column.");
    if (!data.All(row => row.Length == data[0].Length))
        throw new ArgumentException("All rows in the input data must have the same length.");
    _rows = data.Length;
    _cols = data[0].Length;
    _data = new T[_rows][];
    Parallel.For(0, _rows, i =>
    {
        _data[i] = new T[_cols];
        Array.Copy(data[i], _data[i], _cols);
    });
}
/// <summary>
/// Computes the minimum value across all elements in the matrix.
/// </summary>
/// <returns>The smallest value in the matrix as type T.</returns>
/// <exception cref="InvalidOperationException">Thrown if the matrix is empty (has no rows or columns).</exception>
/// <remarks>
/// Each row is reduced in parallel into its own slot, then the per-row minima are folded sequentially.
/// All comparisons go through Compare for type safety.
/// Useful for determining the lower bound of matrix values in optimization or normalization tasks.
/// </remarks>
public T Min()
{
    if (Rows == 0 || Columns == 0)
        throw new InvalidOperationException("Cannot compute minimum of an empty matrix.");
    // One slot per row: workers never share state, so no locking or thread-local
    // accumulator is needed, and no row's result can be lost.
    T[] rowMins = new T[Rows];
    Parallel.For(0, Rows, i =>
    {
        T min = _data[i][0];
        for (int j = 1; j < Columns; j++)
            if (Compare(_data[i][j], min) < 0)
                min = _data[i][j];
        rowMins[i] = min;
    });
    T globalMin = rowMins[0];
    for (int i = 1; i < Rows; i++)
        if (Compare(rowMins[i], globalMin) < 0)
            globalMin = rowMins[i];
    return globalMin;
}
/// <summary>
/// Returns the largest element of the matrix.
/// </summary>
/// <returns>The maximum value, as determined by Compare.</returns>
/// <exception cref="InvalidOperationException">Thrown when the matrix has no rows or no columns.</exception>
/// <remarks>
/// Rows are scanned in parallel, each producing its own maximum; those per-row results are then
/// combined in row order on the calling thread.
/// </remarks>
public T Max()
{
    if (Rows == 0 || Columns == 0)
        throw new InvalidOperationException("Cannot compute maximum of an empty matrix.");
    var perRowMax = new T[Rows];
    Parallel.For(0, Rows, r => perRowMax[r] = LargestInRow(r));
    T best = perRowMax[0];
    for (int r = 1; r < Rows; r++)
    {
        if (Compare(perRowMax[r], best) > 0)
            best = perRowMax[r];
    }
    return best;

    // Left-to-right scan of one row, replacing the candidate only on a strictly greater element.
    T LargestInRow(int rowIndex)
    {
        T candidate = _data[rowIndex][0];
        for (int c = 1; c < Columns; c++)
        {
            if (Compare(_data[rowIndex][c], candidate) > 0)
                candidate = _data[rowIndex][c];
        }
        return candidate;
    }
}
/// <summary>
/// Retrieves the token index with the highest logit value for a specific sequence position.
/// </summary>
/// <param name="logits">Matrix of logit values, where rows represent sequence positions and columns represent token scores.</param>
/// <param name="currentSeqLen">Number of tokens in the current sequence (1-based); the row inspected is
/// <c>currentSeqLen - 1</c>, i.e. the logits of the last position.</param>
/// <returns>The column index (token ID) of the maximum logit in that row, or -1 if logits is empty,
/// currentSeqLen is outside [1, logits.Rows], or the arg-max computation fails.</returns>
/// <exception cref="ArgumentNullException">Thrown if logits is null.</exception>
/// <remarks>
/// Uses ArgMaxInRow and reads the entry for row (currentSeqLen - 1). Errors are written to the console
/// and reported through the -1 return value rather than thrown.
/// </remarks>
public static int GetNextToken(Matrix<double> logits, int currentSeqLen)
{
    if (logits == null)
        throw new ArgumentNullException(nameof(logits), "Logits matrix cannot be null.");
    if (logits.Rows == 0 || logits.Columns == 0)
    {
        Console.WriteLine($"Error: Logits matrix is empty [Rows={logits.Rows}, Columns={logits.Columns}]");
        return -1; // Fallback value
    }
    // currentSeqLen is 1-based, so valid values are 1..Rows (row 0..Rows-1).
    if (currentSeqLen < 1 || currentSeqLen > logits.Rows)
    {
        Console.WriteLine($"Error: currentSeqLen={currentSeqLen} is out of range for logits.Rows={logits.Rows}");
        return -1; // Fallback value
    }
    int rowIndex = currentSeqLen - 1;
    try
    {
        int[] maxIndices = logits.ArgMaxInRow();
        if (maxIndices == null || maxIndices.Length != logits.Rows)
        {
            Console.WriteLine($"Error: ArgMaxInRow returned invalid result (null or length mismatch, expected {logits.Rows}, got {maxIndices?.Length})");
            return -1;
        }
        return maxIndices[rowIndex];
    }
    catch (Exception ex)
    {
        Console.WriteLine($"Exception in GetNextToken: {ex.Message} at row {rowIndex}, logits dimensions [{logits.Rows}, {logits.Columns}]");
        return -1; // Fallback value
    }
}
/// <summary>
/// For every row, finds the column holding that row's largest element.
/// </summary>
/// <returns>An array with one entry per row: the column index of the row maximum (first occurrence on ties).</returns>
/// <exception cref="InvalidOperationException">Thrown when the matrix has no rows.</exception>
/// <remarks>
/// Rows are processed in parallel, each writing only its own result slot; comparisons use Compare.
/// Typically used to pick the predicted class or token index per row.
/// </remarks>
public int[] ArgMaxInRow()
{
    if (Rows == 0)
        throw new InvalidOperationException("Cannot compute ArgMaxInRow on an empty matrix.");
    var winners = new int[Rows];
    Parallel.For(0, Rows, r =>
    {
        T[] row = _data[r];
        T bestValue = row[0];
        int bestCol = 0;
        for (int c = 1; c < Columns; c++)
        {
            if (Compare(row[c], bestValue) <= 0)
                continue;
            bestValue = row[c];
            bestCol = c;
        }
        winners[r] = bestCol;
    });
    return winners;
}
/// <summary>
/// Scales an array to unit L2 length.
/// </summary>
/// <param name="array">Values to normalize; must be non-null and non-empty.</param>
/// <returns>A new array of the same length whose L2 norm is 1, or an all-default (zero) array when the input norm is zero.</returns>
/// <exception cref="ArgumentNullException">Thrown when array is null.</exception>
/// <exception cref="ArgumentException">Thrown when array is empty.</exception>
/// <remarks>
/// The sum of squares is accumulated in parallel (per-worker partial sums merged under a lock), then each
/// element is divided by its square root. The input array is left unchanged.
/// </remarks>
public static T[] NormalizeArray(T[] array)
{
    if (array is null)
        throw new ArgumentNullException(nameof(array), "Input array cannot be null.");
    if (array.Length == 0)
        throw new ArgumentException("Input array must not be empty.", nameof(array));

    T squaredLength = Zero();
    var mergeGate = new object();
    Parallel.For(0, array.Length,
        () => Zero(),
        (idx, loopState, partial) => Add(partial, Multiply(array[idx], array[idx])),
        partial =>
        {
            lock (mergeGate)
            {
                squaredLength = Add(squaredLength, partial);
            }
        });

    T length = Sqrt(squaredLength);
    var unit = new T[array.Length];
    // A zero vector has no direction to preserve; return the zero-initialized array as-is.
    if (Compare(length, Zero()) != 0)
        Parallel.For(0, array.Length, idx => unit[idx] = Divide(array[idx], length));
    return unit;
}
/// <summary>
/// Creates a matrix filled using Xavier (Glorot) uniform initialization.
/// </summary>
/// <param name="rows">Number of rows; must be positive.</param>
/// <param name="cols">Number of columns; must be positive.</param>
/// <returns>A new Matrix&lt;T&gt; with every element drawn uniformly from [-limit, limit), limit = sqrt(6 / (rows + cols)).</returns>
/// <exception cref="ArgumentException">Thrown when rows or cols is zero or negative.</exception>
/// <remarks>
/// Rows are filled in parallel; each worker thread draws from its own per-thread Random, so no lock is taken.
/// Samples are stored as float when T is float and unboxed as double otherwise, so T must be float or double
/// (any other T fails with an InvalidCastException).
/// </remarks>
public static Matrix<T> InitializeXavier(int rows, int cols)
{
    if (rows <= 0) throw new ArgumentException("Number of rows must be positive.", nameof(rows));
    if (cols <= 0) throw new ArgumentException("Number of columns must be positive.", nameof(cols));
    double bound = Math.Sqrt(6.0 / (rows + cols));
    var weights = new Matrix<T>(rows, cols);
    Parallel.For(0, rows, r =>
    {
        Random rng = _random.Value;
        T[] target = weights._data[r];
        for (int c = 0; c < cols; c++)
        {
            double sample = rng.NextDouble() * 2 * bound - bound;
            if (IsFloat)
                target[c] = (T)(object)(float)sample;
            else
                target[c] = (T)(object)sample;
        }
    });
    return weights;
}
/// <summary>
/// Creates a matrix filled using He (Kaiming) normal initialization, suited to ReLU-style activations.
/// </summary>
/// <param name="rows">Number of rows; must be positive.</param>
/// <param name="cols">Number of columns; must be positive.</param>
/// <returns>A new Matrix&lt;T&gt; whose elements are drawn from a normal distribution with mean 0 and standard deviation sqrt(2 / cols).</returns>
/// <exception cref="ArgumentException">Thrown when rows or cols is zero or negative.</exception>
/// <remarks>
/// Normal samples come from the Box-Muller transform. Rows are filled in parallel; each worker thread draws
/// from its own per-thread Random, so no lock is taken. Samples are stored as float when T is float and
/// unboxed as double otherwise, so T must be float or double.
/// </remarks>
public static Matrix<T> InitializeHe(int rows, int cols)
{
    if (rows <= 0) throw new ArgumentException("Number of rows must be positive.", nameof(rows));
    if (cols <= 0) throw new ArgumentException("Number of columns must be positive.", nameof(cols));
    double sigma = Math.Sqrt(2.0 / cols);
    var weights = new Matrix<T>(rows, cols);
    Parallel.For(0, rows, r =>
    {
        Random rng = _random.Value;
        T[] target = weights._data[r];
        for (int c = 0; c < cols; c++)
        {
            // 1 - NextDouble() lies in (0, 1], which keeps Math.Log finite.
            double u1 = 1.0 - rng.NextDouble();
            double u2 = rng.NextDouble();
            double gaussian = Math.Sqrt(-2.0 * Math.Log(u1)) * Math.Cos(2.0 * Math.PI * u2);
            double sample = gaussian * sigma;
            target[c] = IsFloat ? (T)(object)(float)sample : (T)(object)sample;
        }
    });
    return weights;
}
/// <summary>
/// Overwrites one row of this matrix with the contents of a single-row matrix.
/// </summary>
/// <param name="rowIndex">Zero-based row to overwrite; must be in [0, Rows).</param>
/// <param name="rowValues">A [1, Columns] matrix supplying the new values.</param>
/// <exception cref="IndexOutOfRangeException">Thrown if rowIndex is out of bounds.</exception>
/// <exception cref="ArgumentNullException">Thrown if rowValues is null.</exception>
/// <exception cref="ArgumentException">Thrown if rowValues is not exactly one row, or its column count differs from this matrix's.</exception>
/// <remarks>The copy is performed while holding <c>_syncLock</c>.</remarks>
public void SetRow(int rowIndex, Matrix<T> rowValues)
{
    bool rowInRange = rowIndex >= 0 && rowIndex < _rows;
    if (!rowInRange)
        throw new IndexOutOfRangeException($"Row index {rowIndex} is out of bounds for matrix with {Rows} rows.");
    if (rowValues == null)
        throw new ArgumentNullException(nameof(rowValues), "Row values matrix cannot be null.");
    if (rowValues.Rows != 1)
        throw new ArgumentException($"Row values must have exactly 1 row, got {rowValues.Rows}.");
    if (rowValues.Columns != _cols)
        throw new ArgumentException($"Row values columns ({rowValues.Columns}) must match matrix columns ({Columns}).");
    T[] source = rowValues._data[0];
    lock (_syncLock)
    {
        Array.Copy(source, _data[rowIndex], _cols);
    }
}
/// <summary>
/// Increments a single element in place: this[row, col] += value.
/// </summary>
/// <param name="row">Zero-based row index; must be in [0, Rows).</param>
/// <param name="col">Zero-based column index; must be in [0, Columns).</param>
/// <param name="value">Amount to add to the existing element.</param>
/// <exception cref="IndexOutOfRangeException">Thrown if row or col is out of bounds.</exception>
/// <remarks>The read-modify-write happens under <c>_syncLock</c>.</remarks>
public void AddInPlace(int row, int col, T value)
{
    if (!(row >= 0 && row < _rows))
        throw new IndexOutOfRangeException($"Row index {row} is out of bounds for matrix with {Rows} rows.");
    if (!(col >= 0 && col < _cols))
        throw new IndexOutOfRangeException($"Column index {col} is out of bounds for matrix with {Columns} columns.");
    lock (_syncLock)
    {
        T current = _data[row][col];
        _data[row][col] = Add(current, value);
    }
}
/// <summary>
/// Returns a fresh array holding a copy of one row's values.
/// </summary>
/// <param name="rowIndex">Zero-based row to copy; must be in [0, Rows).</param>
/// <returns>A new array of length Columns; modifying it does not affect the matrix.</returns>
/// <exception cref="IndexOutOfRangeException">Thrown if rowIndex is out of bounds.</exception>
public T[] GetRowArray(int rowIndex)
{
    bool valid = rowIndex >= 0 && rowIndex < _rows;
    if (!valid)
        throw new IndexOutOfRangeException($"Row index {rowIndex} is out of bounds for matrix with {Rows} rows.");
    var snapshot = new T[_cols];
    Array.Copy(_data[rowIndex], 0, snapshot, 0, _cols);
    return snapshot;
}
/// <summary>
/// Produces an independent deep copy of this matrix, element by element through the indexer.
/// </summary>
/// <returns>A new Matrix&lt;T&gt; of the same shape holding the same values.</returns>
public Matrix<T> Copy()
{
    int rowCount = Rows;
    int colCount = Columns;
    var duplicate = new Matrix<T>(rowCount, colCount);
    Parallel.For(0, rowCount, r =>
    {
        for (int c = 0; c < colCount; c++)
        {
            duplicate[r, c] = this[r, c];
        }
    });
    return duplicate;
}
/// <summary>
/// Extracts one row as a new [1, Columns] matrix (values are copied).
/// </summary>
/// <param name="rowIndex">Zero-based row to extract; must be in [0, Rows).</param>
/// <returns>A new single-row Matrix&lt;T&gt;.</returns>
/// <exception cref="IndexOutOfRangeException">Thrown if rowIndex is out of bounds.</exception>
public Matrix<T> Row(int rowIndex)
{
    if (!(rowIndex >= 0 && rowIndex < _rows))
        throw new IndexOutOfRangeException($"Row index {rowIndex} is out of bounds for matrix with {Rows} rows.");
    var single = new Matrix<T>(1, _cols);
    T[] destination = single._data[0];
    Array.Copy(_data[rowIndex], destination, _cols);
    return single;
}
/// <summary>
/// Extracts one column as a new [Rows, 1] matrix (values are copied).
/// </summary>
/// <param name="colIndex">Zero-based column to extract; must be in [0, Columns).</param>
/// <returns>A new single-column Matrix&lt;T&gt;.</returns>
/// <exception cref="IndexOutOfRangeException">Thrown if colIndex is out of bounds.</exception>
public Matrix<T> Column(int colIndex)
{
    if (!(colIndex >= 0 && colIndex < _cols))
        throw new IndexOutOfRangeException($"Column index {colIndex} is out of bounds for matrix with {Columns} columns.");
    var single = new Matrix<T>(_rows, 1);
    Parallel.For(0, _rows, r =>
    {
        single._data[r][0] = _data[r][colIndex];
    });
    return single;
}
/// <summary>
/// Tests whether every element satisfies <paramref name="predicate"/>.
/// </summary>
/// <param name="predicate">Condition evaluated per element; rows are scanned in parallel, so it may run concurrently.</param>
/// <returns>True if no element fails the predicate (including when the matrix has no rows); otherwise false.</returns>
public bool All(Func<T, bool> predicate)
{
    // The loop is stopped at the first failing element, so "ran to completion" means "all passed".
    ParallelLoopResult outcome = Parallel.For(0, Rows, (r, loopState) =>
    {
        if (loopState.IsStopped)
            return;
        T[] rowData = _data[r];
        for (int c = 0; c < Columns; c++)
        {
            if (!predicate(rowData[c]))
            {
                loopState.Stop();
                return;
            }
        }
    });
    return outcome.IsCompleted;
}
/// <summary>
/// Tests whether at least one element satisfies <paramref name="predicate"/>.
/// </summary>
/// <param name="predicate">Condition evaluated per element; rows are scanned in parallel, so it may run concurrently.</param>
/// <returns>True as soon as any element matches; false if none do (including when the matrix has no rows).</returns>
public bool Any(Func<T, bool> predicate)
{
    // The loop is stopped at the first match, so an incomplete loop means a match was found.
    ParallelLoopResult outcome = Parallel.For(0, Rows, (r, loopState) =>
    {
        if (loopState.IsStopped)
            return;
        T[] rowData = _data[r];
        for (int c = 0; c < Columns; c++)
        {
            if (predicate(rowData[c]))
            {
                loopState.Stop();
                return;
            }
        }
    });
    return !outcome.IsCompleted;
}
/// <summary>
/// Builds an additive causal (look-ahead) mask for decoder self-attention.
/// </summary>
/// <param name="seqLen">Sequence length; the mask is [seqLen, seqLen].</param>
/// <returns>
/// A mask whose entry [i, j] is 0.0 when j &lt;= i (past or current token) and -1e9 when j &gt; i (future token).
/// A large finite negative value is used rather than -infinity.
/// </returns>
public static Matrix<double> CreateCausalMask(int seqLen)
{
    var mask = new Matrix<double>(seqLen, seqLen);
    Parallel.For(0, seqLen, row =>
    {
        for (int col = 0; col < seqLen; col++)
        {
            mask[row, col] = col <= row ? 0.0 : -1e9;
        }
    });
    return mask;
}
/// <summary>
/// Scales the matrix so that its L2 (Frobenius) norm is 1.
/// </summary>
/// <returns>
/// A new matrix of the same shape with every element divided by the norm, or an all-zero matrix
/// when the norm is at most 1e-6.
/// </returns>
/// <exception cref="InvalidOperationException">Thrown if the matrix is empty, or if the sum of squares is NaN.</exception>
public Matrix<T> Normalize()
{
    if (Rows == 0 || Columns == 0)
        throw new InvalidOperationException("Cannot normalize an empty matrix.");
    T sumOfSquares = this.ElementWiseMultiply(this).Sum();
    // Unbox with the actual element type: unboxing a boxed float as double throws
    // InvalidCastException, which previously made Normalize unusable for float matrices.
    bool sumIsNaN = IsFloat
        ? float.IsNaN((float)(object)sumOfSquares)
        : double.IsNaN((double)(object)sumOfSquares);
    if (sumIsNaN)
        throw new InvalidOperationException("Sum of squares is NaN in Normalize.");
    T norm = Sqrt(sumOfSquares);
    var result = new Matrix<T>(Rows, Columns);
    T epsilon = IsFloat ? (T)(object)1e-6f : (T)(object)1e-6;
    // A (near-)zero norm cannot be divided out meaningfully; return zeros instead.
    if (Compare(norm, epsilon) <= 0)
        return result;
    if (Rows * Columns < ParallelThreshold)
    {
        for (int i = 0; i < Rows; i++)
            for (int j = 0; j < Columns; j++)
                result._data[i][j] = Divide(_data[i][j], norm);
    }
    else
    {
        Parallel.For(0, Rows, i =>
        {
            for (int j = 0; j < Columns; j++)
                result._data[i][j] = Divide(_data[i][j], norm);
        });
    }
    return result;
}
/// <summary>
/// Computes the element-wise reciprocal (1/x) of the matrix.
/// </summary>
/// <returns>
/// A new matrix of the same shape. Elements whose magnitude is below 1e-6 map to zero instead of
/// throwing; NaN also maps to zero, because Compare orders NaN below every value.
/// </returns>
public Matrix<T> ElementWiseInverse()
{
    var inverse = new Matrix<T>(Rows, Columns);
    T threshold = IsFloat ? (T)(object)1e-6f : (T)(object)1e-6;
    Parallel.For(0, Rows, r =>
    {
        for (int c = 0; c < Columns; c++)
        {
            T x = _data[r][c];
            T magnitude = IsFloat ? (T)(object)Math.Abs((float)(object)x) : (T)(object)Math.Abs((double)(object)x);
            inverse[r, c] = Compare(magnitude, threshold) >= 0 ? Divide(One(), x) : Zero();
        }
    });
    return inverse;
}
/// <summary>
/// Computes the square root of each element in the matrix.
/// </summary>
/// <returns>A new matrix with the same dimensions where each element is the square root of the corresponding element.</returns>
/// <exception cref="ArgumentException">
/// Thrown if any element is negative (NaN is also rejected, because Compare orders NaN below zero).
/// The first offending element in row-major order is reported.
/// </exception>
public Matrix<T> ElementWiseSqrt()
{
    // Validate before going parallel: an exception thrown inside Parallel.For reaches the caller
    // wrapped in AggregateException, which would break the documented ArgumentException contract.
    for (int i = 0; i < Rows; i++)
    {
        for (int j = 0; j < Columns; j++)
        {
            T value = _data[i][j];
            if (Compare(value, Zero()) < 0)
                throw new ArgumentException($"Cannot compute square root of negative value {value} at position [{i}, {j}].");
        }
    }
    var result = new Matrix<T>(Rows, Columns);
    Parallel.For(0, Rows, i =>
    {
        for (int j = 0; j < Columns; j++)
            result[i, j] = Sqrt(_data[i][j]);
    });
    return result;
}
/// <summary>
/// Broadcast-adds a [1, Columns] bias row to every row of this matrix.
/// </summary>
/// <param name="bias">Single-row matrix whose column count equals this matrix's.</param>
/// <returns>A new matrix where result[i, j] = this[i, j] + bias[0, j].</returns>
/// <exception cref="ArgumentNullException">Thrown if bias is null.</exception>
/// <exception cref="ArgumentException">Thrown if bias is not one row or its column count differs.</exception>
public Matrix<T> AddBias(Matrix<T> bias)
{
    if (bias == null)
        throw new ArgumentNullException(nameof(bias), "Bias matrix cannot be null.");
    if (bias.Rows != 1)
        throw new ArgumentException($"Bias must have 1 row, got {bias.Rows}.");
    if (bias.Columns != this.Columns)
        throw new ArgumentException($"Bias columns ({bias.Columns}) must match matrix columns ({this.Columns}).");
    T[] biasRow = bias._data[0];
    var shifted = new Matrix<T>(Rows, Columns);
    Parallel.For(0, Rows, r =>
    {
        T[] src = _data[r];
        T[] dst = shifted._data[r];
        for (int c = 0; c < Columns; c++)
        {
            dst[c] = Add(src[c], biasRow[c]);
        }
    });
    return shifted;
}
/// <summary>
/// Applies e^x to every element.
/// </summary>
/// <returns>A new matrix of the same shape holding the exponentials.</returns>
/// <remarks>Runs sequentially below <c>ParallelThreshold</c> elements, otherwise in parallel over rows.</remarks>
public Matrix<T> Exp()
{
    var output = new Matrix<T>(Rows, Columns);
    void ExpRow(int r)
    {
        T[] src = _data[r];
        T[] dst = output._data[r];
        for (int c = 0; c < Columns; c++)
            dst[c] = Exp(src[c]);
    }
    if (Rows * Columns >= ParallelThreshold)
    {
        Parallel.For(0, Rows, ExpRow);
    }
    else
    {
        for (int r = 0; r < Rows; r++)
            ExpRow(r);
    }
    return output;
}
/// <summary>
/// Adds up every element of the matrix.
/// </summary>
/// <returns>The total as type T.</returns>
/// <remarks>
/// Each worker accumulates its rows with Kahan-compensated summation; the per-worker partial sums are
/// then merged under a lock with a second Kahan pass. Merge order depends on scheduling, so the last
/// bits of the result may vary between runs.
/// </remarks>
public T Sum()
{
    T total = Zero();
    T carry = Zero(); // Kahan compensation for the merge step
    object gate = new object();
    Parallel.For(
        0,
        Rows,
        () => (Zero(), Zero()),
        (row, _, partial) =>
        {
            var (acc, err) = partial;
            T[] values = _data[row];
            for (int col = 0; col < Columns; col++)
            {
                T adjusted = Subtract(values[col], err);
                T next = Add(acc, adjusted);
                err = Subtract(Subtract(next, acc), adjusted);
                acc = next;
            }
            return (acc, err);
        },
        partial =>
        {
            lock (gate)
            {
                T adjusted = Subtract(partial.Item1, carry);
                T next = Add(total, adjusted);
                carry = Subtract(Subtract(next, total), adjusted);
                total = next;
            }
        });
    return total;
}
/// <summary>
/// Collapses the row dimension: returns the column-wise totals as a [1, Columns] matrix.
/// </summary>
/// <returns>A new single-row matrix where entry j is the sum of column j.</returns>
public Matrix<T> SumRows()
{
    var totals = new Matrix<T>(1, Columns);
    Parallel.For(0, Columns, c =>
    {
        T acc = Zero();
        int r = 0;
        while (r < Rows)
        {
            acc = Add(acc, _data[r][c]);
            r++;
        }
        totals._data[0][c] = acc;
    });
    return totals;
}
/// <summary>
/// Adds the same scalar to every element.
/// </summary>
/// <param name="scalar">Amount added to each element.</param>
/// <returns>A new matrix of the same shape.</returns>
/// <remarks>Runs sequentially below <c>ParallelThreshold</c> elements, otherwise in parallel over rows.</remarks>
public Matrix<T> AddScalar(T scalar)
{
    var shifted = new Matrix<T>(Rows, Columns);
    void ShiftRow(int r)
    {
        T[] src = _data[r];
        T[] dst = shifted._data[r];
        for (int c = 0; c < Columns; c++)
            dst[c] = Add(src[c], scalar);
    }
    if (Rows * Columns >= ParallelThreshold)
    {
        Parallel.For(0, Rows, ShiftRow);
    }
    else
    {
        for (int r = 0; r < Rows; r++)
            ShiftRow(r);
    }
    return shifted;
}
/// <summary>
/// Locates the maximum element of each row.
/// </summary>
/// <returns>
/// An array of length Rows; entry i is the column index of row i's maximum. On ties the lowest
/// index wins, because only a strictly greater value replaces the current best.
/// </returns>
public int[] ArgMax()
{
    var winners = new int[Rows];
    Parallel.For(0, Rows, r =>
    {
        T[] rowData = _data[r];
        T top = rowData[0];
        int best = 0;
        for (int c = 1; c < Columns; c++)
        {
            T candidate = rowData[c];
            if (Compare(candidate, top) <= 0)
                continue;
            top = candidate;
            best = c;
        }
        winners[r] = best;
    });
    return winners;
}
/// <summary>
/// Computes the Hadamard (element-wise) product with another matrix of identical shape.
/// </summary>
/// <param name="other">Matrix with the same Rows and Columns.</param>
/// <returns>A new matrix where result[i, j] = this[i, j] * other[i, j].</returns>
/// <exception cref="ArgumentNullException">Thrown if other is null.</exception>
/// <exception cref="ArgumentException">Thrown if the shapes differ.</exception>
public Matrix<T> ElementWiseMultiply(Matrix<T> other)
{
    CheckSizes(this, other, "element-wise multiply");
    var product = new Matrix<T>(Rows, Columns);
    void MultiplyRow(int r)
    {
        T[] left = _data[r];
        T[] right = other._data[r];
        T[] dst = product._data[r];
        for (int c = 0; c < Columns; c++)
            dst[c] = Multiply(left[c], right[c]);
    }
    if (Rows * Columns >= ParallelThreshold)
    {
        Parallel.For(0, Rows, MultiplyRow);
    }
    else
    {
        for (int r = 0; r < Rows; r++)
            MultiplyRow(r);
    }
    return product;
}
/// <summary>
/// Returns the transpose of this matrix.
/// </summary>
/// <returns>A new matrix of shape [Columns, Rows] where result[j, i] = this[i, j].</returns>
/// <remarks>Source rows are processed in parallel; each writes a distinct column of the result.</remarks>
public Matrix<T> Transpose()
{
    var flipped = new Matrix<T>(Columns, Rows);
    Parallel.For(0, Rows, r =>
    {
        T[] sourceRow = _data[r];
        for (int c = 0; c < Columns; c++)
        {
            flipped._data[c][r] = sourceRow[c];
        }
    });
    return flipped;
}
/// <summary>
/// Copies a rectangular window of this matrix into a new matrix.
/// </summary>
/// <param name="startRow">First row of the window (inclusive); must be in [0, Rows).</param>
/// <param name="rowCount">Number of rows to copy; must be positive and keep the window inside the matrix.</param>
/// <param name="startCol">First column of the window (inclusive); must be in [0, Columns).</param>
/// <param name="colCount">Number of columns to copy; must be positive and keep the window inside the matrix.</param>
/// <returns>A new Matrix&lt;T&gt; of shape [rowCount, colCount].</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown if any start index or count is invalid or the window exceeds the matrix.</exception>
public Matrix<T> SubMatrix(int startRow, int rowCount, int startCol, int colCount)
{
    if (!(startRow >= 0 && startRow < Rows))
        throw new ArgumentOutOfRangeException(nameof(startRow), $"Start row {startRow} must be between 0 and {Rows - 1}.");
    if (!(startCol >= 0 && startCol < Columns))
        throw new ArgumentOutOfRangeException(nameof(startCol), $"Start column {startCol} must be between 0 and {Columns - 1}.");
    if (!(rowCount > 0))
        throw new ArgumentOutOfRangeException(nameof(rowCount), $"Row count {rowCount} must be positive.");
    if (!(colCount > 0))
        throw new ArgumentOutOfRangeException(nameof(colCount), $"Column count {colCount} must be positive.");
    int rowEnd = startRow + rowCount;
    int colEnd = startCol + colCount;
    if (rowEnd > Rows)
        throw new ArgumentOutOfRangeException(nameof(rowCount), $"Row range {startRow} to {startRow + rowCount - 1} exceeds matrix rows {Rows}.");
    if (colEnd > Columns)
        throw new ArgumentOutOfRangeException(nameof(colCount), $"Column range {startCol} to {startCol + colCount - 1} exceeds matrix columns {Columns}.");
    var window = new Matrix<T>(rowCount, colCount);
    Parallel.For(0, rowCount, r =>
    {
        int sourceRow = startRow + r;
        for (int c = 0; c < colCount; c++)
        {
            window[r, c] = this[sourceRow, startCol + c];
        }
    });
    return window;
}
/// <summary>
/// Adds two matrices element-wise, or broadcasts a single-row bias across every row of <paramref name="a"/>.
/// </summary>
/// <param name="a">Left operand.</param>
/// <param name="b">Right operand: a matrix of the same shape as a, or a [1, a.Columns] bias row.</param>
/// <returns>A new matrix with the sum.</returns>
/// <exception cref="ArgumentNullException">Thrown if either matrix is null.</exception>
/// <exception cref="ArgumentException">Thrown if sizes mismatch for element-wise addition.</exception>
public static Matrix<T> operator +(Matrix<T> a, Matrix<T> b)
{
    // Null-check before reading b.Rows / a.Columns; otherwise a null operand surfaced as
    // NullReferenceException instead of the documented ArgumentNullException.
    if (a is null)
        throw new ArgumentNullException(nameof(a), "Both matrices must be non-null for add.");
    if (b is null)
        throw new ArgumentNullException(nameof(b), "Both matrices must be non-null for add.");
    if (b.Rows == 1 && b.Columns == a.Columns)
        return a.AddBias(b);
    CheckSizes(a, b, "add");
    var result = new Matrix<T>(a.Rows, a.Columns);
    Parallel.For(0, a.Rows, i =>
    {
        for (int j = 0; j < a.Columns; j++)
            result._data[i][j] = Add(a._data[i][j], b._data[i][j]);
    });
    return result;
}
/// <summary>
/// Element-wise difference a - b of two same-shaped matrices.
/// </summary>
/// <param name="a">Minuend.</param>
/// <param name="b">Subtrahend; must have the same shape as a.</param>
/// <returns>A new matrix where result[i, j] = a[i, j] - b[i, j].</returns>
/// <exception cref="ArgumentNullException">Thrown if either matrix is null.</exception>
/// <exception cref="ArgumentException">Thrown if the shapes differ.</exception>
public static Matrix<T> operator -(Matrix<T> a, Matrix<T> b)
{
    CheckSizes(a, b, "subtract");
    var difference = new Matrix<T>(a.Rows, a.Columns);
    Parallel.For(0, a.Rows, r =>
    {
        T[] left = a._data[r];
        T[] right = b._data[r];
        T[] dst = difference._data[r];
        for (int c = 0; c < a.Columns; c++)
        {
            dst[c] = Subtract(left[c], right[c]);
        }
    });
    return difference;
}
/// <summary>
/// Scales every element of a matrix by a scalar (scalar on the left).
/// </summary>
/// <param name="scalar">Factor applied to each element.</param>
/// <param name="m">Matrix to scale.</param>
/// <returns>A new matrix where result[i, j] = scalar * m[i, j].</returns>
public static Matrix<T> operator *(T scalar, Matrix<T> m)
{
    var scaled = new Matrix<T>(m.Rows, m.Columns);
    Parallel.For(0, m.Rows, r =>
    {
        T[] src = m._data[r];
        T[] dst = scaled._data[r];
        for (int c = 0; c < m.Columns; c++)
        {
            dst[c] = Multiply(scalar, src[c]);
        }
    });
    return scaled;
}
/// <summary>
/// Scales every element of a matrix by a scalar (scalar on the right); delegates to the left-scalar form.
/// </summary>
/// <param name="m">Matrix to scale.</param>
/// <param name="scalar">Factor applied to each element.</param>
/// <returns>A new matrix where result[i, j] = scalar * m[i, j].</returns>
public static Matrix<T> operator *(Matrix<T> m, T scalar)
{
    return scalar * m;
}
/// <summary>
/// Performs matrix multiplication between two matrices.
/// </summary>
/// <param name="a">Left operand, shape [m, k].</param>
/// <param name="b">Right operand, shape [k, n].</param>
/// <returns>A new matrix of shape [m, n], the product a * b.</returns>
/// <exception cref="ArgumentNullException">Thrown if either operand is null.</exception>
/// <exception cref="ArgumentException">Thrown if a.Columns does not match b.Rows.</exception>
/// <remarks>
/// Small products run sequentially with plain summation. Larger products are split into row blocks
/// processed in parallel against a single transposed copy of b, using Kahan-compensated dot products.
/// </remarks>
public static Matrix<T> operator *(Matrix<T> a, Matrix<T> b)
{
    // Pass the parameter name explicitly: the single-string ArgumentNullException constructor
    // treats its argument as paramName, not as the message.
    if (a is null)
        throw new ArgumentNullException(nameof(a), "Matrices cannot be null.");
    if (b is null)
        throw new ArgumentNullException(nameof(b), "Matrices cannot be null.");
    if (a.Columns != b.Rows)
        throw new ArgumentException($"Matrix multiplication requires {a.Columns} columns to match {b.Rows} rows.");
    var result = new Matrix<T>(a.Rows, b.Columns);
    // long arithmetic: m*k*n overflows int for moderately large matrices (e.g. 2048^3), and a wrapped
    // negative value would wrongly route huge products to the sequential path.
    long totalOps = (long)a.Rows * a.Columns * b.Columns;
    int numThreads = Environment.ProcessorCount;
    const int minOpsPerThread = 1000; // Tune based on profiling
    if (totalOps < (long)numThreads * minOpsPerThread)
    {
        // Sequential for very small matrices
        for (int i = 0; i < a.Rows; i++)
        {
            for (int j = 0; j < b.Columns; j++)
            {
                T sum = Zero();
                for (int k = 0; k < a.Columns; k++)
                    sum = Add(sum, Multiply(a._data[i][k], b._data[k][j]));
                result._data[i][j] = sum;
            }
        }
    }
    else
    {
        // Adaptive block size based on cache (e.g., L1 cache ~32KB, 4096 doubles)
        const int cacheSizeElements = 4096;
        int blockSize = Math.Min(Math.Max(16, cacheSizeElements / Math.Max(b.Columns, a.Columns)), a.Rows);
        int rowBlocks = (a.Rows + blockSize - 1) / blockSize;
        // Transpose b once up front so inner products read contiguous rows. Previously this was
        // recomputed inside every row-block iteration, repeating an O(k*n) copy per block.
        T[][] bT = b.Transpose()._data; // [b.Columns, b.Rows]
        Parallel.For(0, rowBlocks, new ParallelOptions { MaxDegreeOfParallelism = numThreads }, ib =>
        {
            int startRow = ib * blockSize;
            int endRow = Math.Min(startRow + blockSize, a.Rows);
            for (int i = startRow; i < endRow; i++)
            {
                T[] aRow = a._data[i];
                for (int j = 0; j < b.Columns; j++)
                {
                    T[] bCol = bT[j];
                    T sum = Zero();
                    T compensation = Zero(); // Kahan summation
                    for (int k = 0; k < a.Columns; k++)
                    {
                        T product = Multiply(aRow[k], bCol[k]);
                        T y = Subtract(product, compensation);
                        T t = Add(sum, y);
                        compensation = Subtract(Subtract(t, sum), y);
                        sum = t;
                    }
                    result._data[i][j] = sum;
                }
            }
        });
    }
    return result;
}
/// <summary>
/// Builds a [rows, cols] matrix in which every element is one.
/// </summary>
/// <param name="rows">Row count; must be positive.</param>
/// <param name="cols">Column count; must be positive.</param>
/// <returns>A new all-ones Matrix&lt;T&gt;.</returns>
/// <exception cref="ArgumentException">Thrown if rows or cols are less than or equal to zero.</exception>
public static Matrix<T> Ones(int rows, int cols)
{
    if (!(rows > 0))
        throw new ArgumentException("Number of rows must be positive.", nameof(rows));
    if (!(cols > 0))
        throw new ArgumentException("Number of columns must be positive.", nameof(cols));
    var filled = new Matrix<T>(rows, cols);
    T one = One();
    Parallel.For(0, rows, r =>
    {
        for (int c = 0; c < cols; c++)
        {
            filled[r, c] = one;
        }
    });
    return filled;
}
/// <summary>
/// Builds a [rows, cols] matrix in which every element is zero.
/// </summary>
/// <param name="rows">Row count; must be positive.</param>
/// <param name="cols">Column count; must be positive.</param>
/// <returns>A new all-zero Matrix&lt;T&gt;, relying on the constructor's default-initialized storage.</returns>
/// <exception cref="ArgumentException">Thrown if rows or cols are less than or equal to zero.</exception>
public static Matrix<T> Zeros(int rows, int cols)
{
    if (!(rows > 0))
        throw new ArgumentException("Number of rows must be positive.", nameof(rows));
    if (!(cols > 0))
        throw new ArgumentException("Number of columns must be positive.", nameof(cols));
    var zeros = new Matrix<T>(rows, cols);
    return zeros;
}
/// <summary>
/// Creates a matrix with random values between -maxValue and +maxValue using uniform distribution.
/// </summary>
/// <param name="rows">Number of rows in the resulting matrix.</param>
/// <param name="cols">Number of columns in the resulting matrix.</param>
/// <param name="maxValue">Maximum absolute value for random entries; each entry is u * maxValue with u uniform in [-1, 1).</param>
/// <param name="seed">
/// Optional seed. When supplied, one seed per row is derived sequentially from it before the parallel
/// loop, so the same seed reproduces the same matrix regardless of thread scheduling (within the same
/// .NET runtime's Random implementation). When null, rows draw from <c>_random.Value</c>.
/// </param>
/// <returns>A new matrix filled with random values.</returns>
public static Matrix<T> Random(int rows, int cols, T maxValue, int? seed = null)
{
    var result = new Matrix<T>(rows, cols);
    // Derive per-row seeds up front, on one thread: System.Random is not thread-safe, and drawing
    // seeds from a shared instance inside Parallel.For both risks corrupting its state and makes the
    // row-to-seed mapping depend on scheduling, defeating the purpose of a seed.
    int[] rowSeeds = null;
    if (seed.HasValue)
    {
        var seeded = new Random(seed.Value);
        rowSeeds = new int[rows];
        for (int i = 0; i < rows; i++)
            rowSeeds[i] = seeded.Next();
    }
    Parallel.For(0, rows, i =>
    {
        Random rowRand = rowSeeds != null ? new Random(rowSeeds[i]) : _random.Value;
        for (int j = 0; j < cols; j++)
        {
            double u = 2.0 * rowRand.NextDouble() - 1.0;
            // Box with the element type's own representation: unboxing a boxed double as float throws.
            T unit = IsFloat ? (T)(object)(float)u : (T)(object)u;
            result._data[i][j] = Multiply(unit, maxValue);
        }
    });
    return result;
}
/// <summary>
/// Builds the size x size identity matrix (ones on the main diagonal, zeros elsewhere).
/// </summary>
/// <param name="size">Number of rows and of columns.</param>
/// <returns>A new identity Matrix&lt;T&gt;.</returns>
public static Matrix<T> Identity(int size)
{
    var eye = new Matrix<T>(size, size);
    Parallel.For(0, size, d =>
    {
        eye._data[d][d] = One();
    });
    return eye;
}
/// <summary>
/// Validates that two matrices have the same dimensions for element-wise operations.
/// </summary>
/// <param name="a">First matrix to check.</param>
/// <param name="b">Second matrix to check.</param>
/// <param name="operation">Description of the operation for error messages.</param>
/// <exception cref="ArgumentNullException">Thrown if either matrix is null; ParamName identifies which one.</exception>
/// <exception cref="ArgumentException">Thrown if dimensions mismatch.</exception>
private static void CheckSizes(Matrix<T> a, Matrix<T> b, string operation)
{
    // Use the (paramName, message) overload: the single-string constructor treats its argument as
    // the parameter name, so the old code reported the whole sentence as ParamName.
    if (a is null)
        throw new ArgumentNullException(nameof(a), $"Both matrices must be non-null for {operation}.");
    if (b is null)
        throw new ArgumentNullException(nameof(b), $"Both matrices must be non-null for {operation}.");
    if (a.Rows != b.Rows || a.Columns != b.Columns)
        throw new ArgumentException($"For {operation}, matrices must have same dimensions, got [{a.Rows}, {a.Columns}] vs [{b.Rows}, {b.Columns}].");
}
/// <summary>
/// Produces an independent deep copy of this matrix by copying each row array.
/// </summary>
/// <returns>A new Matrix&lt;T&gt; of the same shape holding the same values.</returns>
public Matrix<T> Clone()
{
    var twin = new Matrix<T>(Rows, Columns);
    void CopyRow(int r) => Array.Copy(_data[r], twin._data[r], Columns);
    Parallel.For(0, Rows, CopyRow);
    return twin;
}
/// <summary>
/// Renders the matrix as text: one line per row, each formatted as "[v0, v1, ...]".
/// </summary>
/// <returns>The rows joined with "\n".</returns>
public override string ToString()
{
    string[] lines = Array.ConvertAll(_data, row => "[" + string.Join(", ", row) + "]");
    return string.Join("\n", lines);
}
/// <summary>
/// Value equality: true when <paramref name="obj"/> is a Matrix&lt;T&gt; of the same shape with equal elements.
/// </summary>
/// <param name="obj">Object to compare against.</param>
/// <returns>True if shapes and every element match; false otherwise (including when obj is null or another type).</returns>
public override bool Equals(object obj)
{
    if (!(obj is Matrix<T> other))
        return false;
    if (other.Rows != Rows || other.Columns != Columns)
        return false;
    for (int r = 0; r < Rows; r++)
    {
        T[] mine = _data[r];
        T[] theirs = other._data[r];
        for (int c = 0; c < Columns; c++)
        {
            if (!Equals(mine[c], theirs[c]))
                return false;
        }
    }
    return true;
}
/// <summary>
/// Generates a hash code from every element, consistent with <see cref="Equals(object)"/>.
/// </summary>
/// <returns>A hash code combining all elements in row-major order.</returns>
/// <remarks>
/// Cost is O(Rows * Columns). Because the matrix is mutable, avoid mutating it while it is stored
/// as a key in a hash-based collection.
/// </remarks>
public override int GetHashCode()
{
    // The multiply-add is meant to wrap around; 'unchecked' keeps it from throwing OverflowException
    // when the project is compiled with arithmetic overflow checking enabled.
    unchecked
    {
        int hash = 17;
        for (int i = 0; i < Rows; i++)
            for (int j = 0; j < Columns; j++)
                hash = hash * 23 + _data[i][j].GetHashCode();
        return hash;
    }
}
/// <summary>
/// Returns a + b, computed in float when T is float and in double otherwise.
/// </summary>
/// <param name="a">First operand.</param>
/// <param name="b">Second operand.</param>
/// <returns>The sum as type T.</returns>
public static T Add(T a, T b)
{
    if (IsFloat)
        return (T)(object)((float)(object)a + (float)(object)b);
    return (T)(object)((double)(object)a + (double)(object)b);
}
/// <summary>
/// Returns a - b, computed in float when T is float and in double otherwise.
/// </summary>
/// <param name="a">Minuend.</param>
/// <param name="b">Subtrahend.</param>
/// <returns>The difference as type T.</returns>
public static T Subtract(T a, T b)
{
    if (IsFloat)
        return (T)(object)((float)(object)a - (float)(object)b);
    return (T)(object)((double)(object)a - (double)(object)b);
}
/// <summary>
/// Returns a * b, computed in float when T is float and in double otherwise.
/// </summary>
/// <param name="a">First factor.</param>
/// <param name="b">Second factor.</param>
/// <returns>The product as type T.</returns>
public static T Multiply(T a, T b)
{
    if (IsFloat)
        return (T)(object)((float)(object)a * (float)(object)b);
    return (T)(object)((double)(object)a * (double)(object)b);
}
/// <summary>
/// Divides one value by another, clamping tiny denominators to +/-1e-6 to avoid division by zero.
/// </summary>
/// <param name="a">Numerator.</param>
/// <param name="b">Denominator.</param>
/// <returns>
/// a / b as type T. If |b| &lt; 1e-6, b is replaced by -1e-6 when b is negative and by +1e-6
/// otherwise (including for zero of either sign). A NaN denominator is not clamped, so NaN propagates.
/// </returns>
public static T Divide(T a, T b)
{
    // Native comparisons are used on purpose: CompareTo orders NaN below every value, which made the
    // old code treat a NaN denominator as "tiny negative" and silently return a / -1e-6.
    if (IsFloat)
    {
        const float eps = 1e-6f;
        float num = (float)(object)a;
        float den = (float)(object)b;
        if (Math.Abs(den) < eps)
            den = den < 0f ? -eps : eps;
        return (T)(object)(num / den);
    }
    else
    {
        const double eps = 1e-6;
        double num = (double)(object)a;
        double den = (double)(object)b;
        if (Math.Abs(den) < eps)
            den = den < 0.0 ? -eps : eps;
        return (T)(object)(num / den);
    }
}
/// <summary>
/// Returns e raised to <paramref name="x"/>; evaluated with Math.Exp and narrowed back to float when T is float.
/// </summary>
/// <param name="x">The exponent.</param>
/// <returns>e^x as type T.</returns>
public static T Exp(T x)
{
    if (IsFloat)
    {
        float exponent = (float)(object)x;
        return (T)(object)(float)Math.Exp(exponent);
    }
    return (T)(object)Math.Exp((double)(object)x);
}
/// <summary>
/// Returns the square root of <paramref name="x"/>; evaluated with Math.Sqrt and narrowed back to float when T is float.
/// </summary>
/// <param name="x">Value to take the root of; must not be negative.</param>
/// <returns>The square root as type T.</returns>
/// <exception cref="ArgumentException">
/// Thrown if x is negative. NaN is rejected too, because Compare orders NaN below zero.
/// </exception>
public static T Sqrt(T x)
{
    bool negative = Compare(x, Zero()) < 0;
    if (negative)
        throw new ArgumentException("Cannot compute square root of a negative number.");
    if (IsFloat)
        return (T)(object)(float)Math.Sqrt((float)(object)x);
    return (T)(object)Math.Sqrt((double)(object)x);
}
/// <summary>
/// Orders two values using float.CompareTo / double.CompareTo.
/// </summary>
/// <param name="a">First value.</param>
/// <param name="b">Second value.</param>
/// <returns>
/// A negative number if a &lt; b, zero if they are equal, a positive number if a &gt; b.
/// NaN sorts before every other value and equal to another NaN.
/// </returns>
public static int Compare(T a, T b)
{
    if (IsFloat)
    {
        float left = (float)(object)a;
        return left.CompareTo((float)(object)b);
    }
    double lhs = (double)(object)a;
    return lhs.CompareTo((double)(object)b);
}
/// <summary>
/// Additive identity for T.
/// </summary>
/// <returns>0f when T is float; otherwise a boxed 0.0 cast to T.</returns>
public static T Zero()
{
    if (IsFloat)
        return (T)(object)0f;
    return (T)(object)0.0;
}
/// <summary>
/// Multiplicative identity for T.
/// </summary>
/// <returns>1f when T is float; otherwise a boxed 1.0 cast to T.</returns>
public static T One()
{
    if (IsFloat)
        return (T)(object)1f;
    return (T)(object)1.0;
}
}
/// <summary>
/// Represents a vector, a specialized matrix with either one row (row vector) or one column (column vector).
/// Inherits from Matrix<T> to leverage its functionality while providing vector-specific operations.
/// </summary>
/// <typeparam name="T">The numeric type, constrained to float or double, implementing IComparable and IEquatable.</typeparam>
public class Vector<T> : Matrix<T> where T : struct, IComparable<T>, IEquatable<T>
{
#region Fields:
/// <summary>
/// True when the vector is stored as a [1, n] row vector; false when stored as an [n, 1] column vector.
/// </summary>
private readonly bool _isRowVector;
/// <summary>
/// Number of elements: Rows for a column vector, Columns for a row vector.
/// </summary>
public int Length => !_isRowVector ? Rows : Columns;
#endregion
#region Properties:
/// <summary>
/// Reads or writes the element at a one-dimensional position, mapped onto the underlying [row, column] storage.
/// </summary>
/// <param name="index">Zero-based element position; expected to be in [0, Length).</param>
/// <returns>The element at that position.</returns>
/// <exception cref="IndexOutOfRangeException">
/// Thrown if index is out of bounds — presumably raised by the base two-argument indexer (not visible here; confirm).
/// </exception>
public T this[int index]
{
    get
    {
        int row = _isRowVector ? 0 : index;
        int col = _isRowVector ? index : 0;
        return this[row, col];
    }
    set
    {
        int row = _isRowVector ? 0 : index;
        int col = _isRowVector ? index : 0;
        this[row, col] = value;
    }
}
#endregion
/// <summary>
/// Initializes a vector with the specified length, treated as a column vector by default.
/// </summary>
/// <param name="length">Number of elements, must be positive.</param>
/// <param name="isRowVector">If true, creates a row vector [1, length]; otherwise, a column vector [length, 1].</param>
/// <exception cref="ArgumentException">Thrown if length is less than or equal to zero.</exception>
public Vector(int length, bool isRowVector = false)
    : base(isRowVector ? 1 : RequirePositiveLength(length), isRowVector ? RequirePositiveLength(length) : 1)
{
    _isRowVector = isRowVector;
}
/// <summary>
/// Validates a vector length inside the base-constructor call, so an invalid length raises this
/// constructor's documented ArgumentException before the base Matrix constructor ever sees it.
/// </summary>
/// <param name="length">Requested vector length.</param>
/// <returns>The length unchanged when it is positive.</returns>
/// <exception cref="ArgumentException">Thrown if length is less than or equal to zero.</exception>
private static int RequirePositiveLength(int length)
{
    if (length <= 0)
        throw new ArgumentException("Vector length must be positive.", nameof(length));
    return length;
}
/// <summary>
/// Initializes a vector from an array of values, treated as a column vector by default.
/// </summary>
/// <param name="data">Array of values, must be non-null and non-empty.</param>
/// <param name="isRowVector">If true, creates a row vector [1, n]; otherwise, a column vector [n, 1].</param>
/// <exception cref="ArgumentNullException">Thrown if data is null.</exception>
/// <exception cref="ArgumentException">Thrown if data is empty.</exception>
public Vector(T[] data, bool isRowVector = false) : base(ToJaggedLayout(data, isRowVector))
{
    _isRowVector = isRowVector;
}
/// <summary>
/// Validates <paramref name="data"/> and reshapes it into the jagged layout the base constructor expects.
/// The checks run here, inside the base-constructor call, because the base initializer executes before
/// the constructor body; checks placed in the body were unreachable for null or empty input.
/// </summary>
/// <param name="data">Element values.</param>
/// <param name="isRowVector">True for a single row holding all values; false for one single-element row per value.</param>
/// <returns>The jagged array to hand to the base Matrix constructor.</returns>
/// <exception cref="ArgumentNullException">Thrown if data is null.</exception>
/// <exception cref="ArgumentException">Thrown if data is empty.</exception>
private static T[][] ToJaggedLayout(T[] data, bool isRowVector)
{
    if (data == null)
        throw new ArgumentNullException(nameof(data), "Vector data cannot be null.");
    if (data.Length == 0)
        throw new ArgumentException("Vector data must not be empty.", nameof(data));
    return isRowVector
        ? new T[][] { data }
        : Array.ConvertAll(data, x => new T[] { x });
}
/// <summary>
/// Computes the dot product of this vector with another vector.
/// </summary>
/// <param name="other">The other vector, must have the same length.</param>
/// <returns>The dot product as type T.</returns>
/// <exception cref="ArgumentNullException">Thrown if other is null.</exception>
/// <exception cref="ArgumentException">Thrown if vectors have different lengths or incompatible orientations.</exception>
public T Dot(Vector<T> other)
{
if (other == null)
throw new ArgumentNullException(nameof(other));
if (Length != other.Length)
throw new ArgumentException($"Vectors must have same length, got {Length} vs {other.Length}.");
if (_isRowVector && !other._isRowVector && Columns == other.Rows)
return (this * other)[0, 0];
T sum = Zero();
T comp = Zero();
object lockObj = new object();
Parallel.For(0, Length, () => (Zero(), Zero()), (i, state, local) =>
{
T localSum = local.Item1;
T localComp = local.Item2;
T y = Subtract(Multiply(this[i], other[i]), localComp);
T t = Add(localSum, y);
localComp = Subtract(Subtract(t, localSum), y);
localSum = t;
return (localSum, localComp);
}, local =>
{
lock (lockObj)
{
T y = Subtract(local.Item1, comp);
T t = Add(sum, y);
comp = Subtract(Subtract(t, sum), y);
sum = t;
}
});
return sum;
}
/// <summary>
/// Computes the magnitude (L2 norm) of the vector.
/// </summary>
/// <returns>The magnitude as type T.</returns>
public T Magnitude()
{
return Sqrt(this.ElementWiseMultiply(this).Sum());
}
/// <summary>
/// Normalizes the vector to have a magnitude of 1.
/// </summary>
/// <returns>A new Vector<T> with the same direction but unit length.</returns>
/// <exception cref="InvalidOperationException">Thrown if the magnitude is zero.</exception>
public Vector<T> Normalize()
{
var matrix = this.Normalize();
return new Vector<T>(matrix.GetRowArray(0), _isRowVector);
}
/// <summary>
/// Converts the vector to an array.
/// </summary>
/// <returns>An array of type T containing the vector elements.</returns>
public T[] ToArray()
{
return _isRowVector ? GetRowArray(0) : Enumerable.Range(0, Rows).Select(i => this[i, 0]).ToArray();
}
/// <summary>
/// Converts the vector to a row vector if it isn’t already.
/// </summary>
/// <returns>A new Vector<T> as a row vector [1, Length].</returns>
public Vector<T> ToRowVector()
{
if (_isRowVector) return this;
return new Vector<T>(Enumerable.Range(0, Rows).Select(i => this[i, 0]).ToArray(), true);
}
/// <summary>
/// Converts the vector to a column vector if it isn’t already.
/// </summary>
/// <returns>A new Vector<T> as a column vector [Length, 1].</returns>
public Vector<T> ToColumnVector()
{
if (!_isRowVector) return this;
return new Vector<T>(GetRowArray(0), false);
}
}
}
Here is a basic Transformer model:
using System;
using System.Linq;
namespace TransformerModel
{
    // C# has no free functions: the shared math helpers live in TransformerMath and are imported here.
    using static TransformerModel.TransformerMath;

    /// <summary>
    /// Represents a simplified Transformer model with multi-head self-attention, suitable for small-scale training tasks.
    /// </summary>
    /// <remarks>
    /// This implementation includes the core components of a Transformer: multi-head attention, feed-forward layers,
    /// layer normalization, and positional encoding. It is designed for educational purposes or small models, not
    /// optimized for production-scale tasks like large language models.
    /// </remarks>
    public class Transformer
    {
        private readonly int _dModel; // Dimension of the model (embedding size)
        private readonly int _numHeads; // Number of attention heads
        private readonly int _dFF; // Dimension of the feed-forward network
        private readonly int _maxSeqLength; // Maximum sequence length for positional encoding
        private readonly double[][] _posEncoding; // Precomputed positional encodings [maxSeqLength, dModel]
        private readonly MultiHeadAttention _attention; // Attention mechanism
        private readonly FeedForwardNetwork _ffn; // Feed-forward network
        private readonly LayerNormalization _norm1; // First layer normalization
        private readonly LayerNormalization _norm2; // Second layer normalization
        private readonly Random _rand; // Single generator shared by all sub-layers for weight initialization

        /// <summary>
        /// Initializes a new instance of the <see cref="Transformer"/> class.
        /// </summary>
        /// <param name="dModel">The dimension of the input/output embeddings (must be divisible by numHeads).</param>
        /// <param name="numHeads">The number of attention heads.</param>
        /// <param name="dFF">The dimension of the feed-forward network hidden layer.</param>
        /// <param name="maxSeqLength">The maximum sequence length for positional encoding.</param>
        /// <param name="seed">Optional seed for weight initialization; pass a value for reproducible models.</param>
        /// <exception cref="ArgumentException">Thrown if parameters are invalid.</exception>
        public Transformer(int dModel, int numHeads, int dFF, int maxSeqLength, int? seed = null)
        {
            if (dModel <= 0 || numHeads <= 0 || dFF <= 0 || maxSeqLength <= 0)
                throw new ArgumentException("All dimensions and lengths must be positive.");
            if (dModel % numHeads != 0)
                throw new ArgumentException("dModel must be divisible by numHeads.");
            _dModel = dModel;
            _numHeads = numHeads;
            _dFF = dFF;
            _maxSeqLength = maxSeqLength;
            _rand = seed.HasValue ? new Random(seed.Value) : new Random();
            _posEncoding = GeneratePositionalEncoding(dModel, maxSeqLength);
            // Sharing one Random avoids separately constructed generators that may receive the same seed
            // (and therefore identical weights) when created in quick succession.
            _attention = new MultiHeadAttention(dModel, numHeads, _rand);
            _ffn = new FeedForwardNetwork(dModel, dFF, _rand);
            _norm1 = new LayerNormalization(dModel);
            _norm2 = new LayerNormalization(dModel);
        }

        /// <summary>Gets the attention sub-layer (used by training code in this assembly).</summary>
        internal MultiHeadAttention Attention => _attention;

        /// <summary>Gets the feed-forward sub-layer (used by training code in this assembly).</summary>
        internal FeedForwardNetwork FFN => _ffn;

        /// <summary>Gets the layer normalization applied after attention.</summary>
        internal LayerNormalization Norm1 => _norm1;

        /// <summary>Gets the layer normalization applied after the feed-forward network.</summary>
        internal LayerNormalization Norm2 => _norm2;

        /// <summary>
        /// Performs a forward pass through the Transformer layer.
        /// </summary>
        /// <param name="input">The input matrix [sequenceLength, dModel], with 1 &lt;= sequenceLength &lt;= maxSeqLength.</param>
        /// <returns>The output matrix [sequenceLength, dModel] after applying attention and feed-forward layers.</returns>
        /// <exception cref="ArgumentNullException">Thrown if <paramref name="input"/> is null.</exception>
        /// <exception cref="ArgumentException">Thrown if input is empty, has null rows, or has invalid dimensions.</exception>
        public double[][] Forward(double[][] input)
        {
            if (input == null)
                throw new ArgumentNullException(nameof(input), "Input cannot be null.");
            // An empty sequence has nothing to attend over and would fail deep inside the attention math.
            if (input.Length == 0)
                throw new ArgumentException("Input must contain at least one position.", nameof(input));
            if (input.Length > _maxSeqLength || input.Any(row => row == null || row.Length != _dModel))
                throw new ArgumentException("Input dimensions must match [<=maxSeqLength, dModel].");
            // Add positional encoding
            double[][] x = AddPositionalEncoding(input);
            // Multi-head attention + residual connection + normalization
            double[][] attnOutput = _attention.Forward(x, x, x); // Q, K, V are the same for self-attention
            x = AddResidual(x, attnOutput);
            x = _norm1.Normalize(x);
            // Feed-forward network + residual connection + normalization
            double[][] ffnOutput = _ffn.Forward(x);
            x = AddResidual(x, ffnOutput);
            x = _norm2.Normalize(x);
            return x;
        }

        /// <summary>
        /// Adds positional encoding to the input sequence (position i receives encoding row i).
        /// </summary>
        private double[][] AddPositionalEncoding(double[][] input)
        {
            double[][] result = new double[input.Length][];
            for (int i = 0; i < input.Length; i++)
            {
                result[i] = new double[_dModel];
                for (int j = 0; j < _dModel; j++)
                {
                    result[i][j] = input[i][j] + _posEncoding[i][j];
                }
            }
            return result;
        }

        /// <summary>
        /// Adds a residual connection (element-wise addition).
        /// </summary>
        private double[][] AddResidual(double[][] x, double[][] residual)
        {
            double[][] result = new double[x.Length][];
            for (int i = 0; i < x.Length; i++)
            {
                result[i] = new double[_dModel];
                for (int j = 0; j < _dModel; j++)
                {
                    result[i][j] = x[i][j] + residual[i][j];
                }
            }
            return result;
        }

        /// <summary>
        /// Generates sinusoidal positional encoding as per the Transformer paper.
        /// </summary>
        private static double[][] GeneratePositionalEncoding(int dModel, int maxSeqLength)
        {
            double[][] encoding = new double[maxSeqLength][];
            for (int pos = 0; pos < maxSeqLength; pos++)
            {
                encoding[pos] = new double[dModel];
                for (int i = 0; i < dModel; i++)
                {
                    // i / 2 is integer division, so each (sin, cos) pair of dimensions shares one frequency.
                    double angle = pos / Math.Pow(10000, 2.0 * (i / 2) / dModel);
                    encoding[pos][i] = i % 2 == 0 ? Math.Sin(angle) : Math.Cos(angle);
                }
            }
            return encoding;
        }
    }

    /// <summary>
    /// Implements the multi-head attention mechanism.
    /// </summary>
    /// <remarks>
    /// Q, K and V are each projected with a full [dModel, dModel] matrix; head h then attends over its own column
    /// slice [h*dK, (h+1)*dK) of those projections. The concatenated heads are mixed by an output projection W^O.
    /// </remarks>
    internal partial class MultiHeadAttention
    {
        private readonly int _dModel;
        private readonly int _numHeads;
        private readonly int _dK; // Dimension per head
        private readonly double[][][] _weights; // Q, K, V projection weights, each [dModel, dModel]
        private readonly double[][] _outputWeights; // Output projection W^O [dModel, dModel]

        /// <summary>
        /// Creates the attention layer with small random weights in [-0.05, 0.05).
        /// </summary>
        /// <param name="dModel">Model dimension; assumed divisible by <paramref name="numHeads"/>.</param>
        /// <param name="numHeads">Number of attention heads.</param>
        /// <param name="rand">Optional shared random generator; a new one is created when null.</param>
        public MultiHeadAttention(int dModel, int numHeads, Random rand = null)
        {
            _dModel = dModel;
            _numHeads = numHeads;
            _dK = dModel / numHeads;
            Random rng = rand ?? new Random();
            _weights = new double[3][][];
            for (int i = 0; i < 3; i++) // Q, K, V
            {
                _weights[i] = RandomMatrix(dModel, dModel, rng);
            }
            _outputWeights = RandomMatrix(dModel, dModel, rng);
        }

        /// <summary>
        /// Computes the multi-head attention output [queryLength, dModel].
        /// </summary>
        public double[][] Forward(double[][] query, double[][] key, double[][] value)
        {
            // Project once; every head reads a distinct column slice, so heads no longer compute identical outputs.
            double[][] q = MatMul(query, _weights[0]);
            double[][] k = MatMul(key, _weights[1]);
            double[][] v = MatMul(value, _weights[2]);
            double[][][] heads = new double[_numHeads][][];
            for (int h = 0; h < _numHeads; h++)
            {
                int offset = h * _dK;
                heads[h] = ScaledDotProductAttention(
                    SliceColumns(q, offset, _dK),
                    SliceColumns(k, offset, _dK),
                    SliceColumns(v, offset, _dK));
            }
            return MatMul(ConcatenateHeads(heads), _outputWeights);
        }

        private double[][] ScaledDotProductAttention(double[][] q, double[][] k, double[][] v)
        {
            double[][] scores = MatMul(q, Transpose(k));
            double scale = Math.Sqrt(_dK);
            for (int i = 0; i < scores.Length; i++)
                for (int j = 0; j < scores[i].Length; j++)
                    scores[i][j] /= scale;
            double[][] attnWeights = Softmax(scores);
            return MatMul(attnWeights, v);
        }

        /// <summary>
        /// Concatenates per-head outputs row by row: [seq, dK] x numHeads -> [seq, numHeads * dK].
        /// </summary>
        private static double[][] ConcatenateHeads(double[][][] heads)
        {
            double[][] result = new double[heads[0].Length][];
            for (int i = 0; i < result.Length; i++)
            {
                result[i] = heads.SelectMany(h => h[i]).ToArray();
            }
            return result;
        }

        /// <summary>
        /// Copies columns [start, start + count) of every row into a new matrix.
        /// </summary>
        private static double[][] SliceColumns(double[][] m, int start, int count)
        {
            double[][] result = new double[m.Length][];
            for (int i = 0; i < m.Length; i++)
            {
                result[i] = new double[count];
                Array.Copy(m[i], start, result[i], 0, count);
            }
            return result;
        }
    }

    /// <summary>
    /// Implements a two-layer feed-forward neural network (Linear -> ReLU -> Linear) for the Transformer.
    /// </summary>
    internal partial class FeedForwardNetwork
    {
        private readonly double[][] _w1; // First layer weights [dModel, dFF]
        private readonly double[] _b1; // First layer biases [dFF]
        private readonly double[][] _w2; // Second layer weights [dFF, dModel]
        private readonly double[] _b2; // Second layer biases [dModel]

        /// <summary>
        /// Creates the network with small random weights and biases in [-0.05, 0.05).
        /// </summary>
        /// <param name="rand">Optional shared random generator; a new one is created when null.</param>
        public FeedForwardNetwork(int dModel, int dFF, Random rand = null)
        {
            Random rng = rand ?? new Random();
            _w1 = RandomMatrix(dModel, dFF, rng);
            _b1 = RandomVector(dFF, rng);
            _w2 = RandomMatrix(dFF, dModel, rng);
            _b2 = RandomVector(dModel, rng);
        }

        /// <summary>
        /// Applies the network to every row of <paramref name="input"/> [seq, dModel], returning [seq, dModel].
        /// </summary>
        public double[][] Forward(double[][] input)
        {
            double[][] hidden = ReLU(MatMul(input, _w1).Select(row => Add(row, _b1)).ToArray());
            return MatMul(hidden, _w2).Select(row => Add(row, _b2)).ToArray();
        }
    }

    /// <summary>
    /// Implements layer normalization over the last dimension of each row.
    /// </summary>
    internal partial class LayerNormalization
    {
        private const double Epsilon = 1e-6; // Guards against division by zero for constant rows
        private readonly int _dModel;
        private readonly double[] _gamma; // Scale, initialized to 1
        private readonly double[] _beta; // Shift, initialized to 0

        public LayerNormalization(int dModel)
        {
            _dModel = dModel;
            _gamma = Enumerable.Repeat(1.0, dModel).ToArray();
            _beta = new double[dModel];
        }

        /// <summary>
        /// Normalizes each row to zero mean and unit (population) variance, then applies gamma and beta.
        /// </summary>
        public double[][] Normalize(double[][] input)
        {
            double[][] result = new double[input.Length][];
            for (int i = 0; i < input.Length; i++)
            {
                double mean = input[i].Average();
                double variance = input[i].Select(x => Math.Pow(x - mean, 2)).Average();
                double std = Math.Sqrt(variance + Epsilon);
                result[i] = input[i].Select((x, j) => _gamma[j] * (x - mean) / std + _beta[j]).ToArray();
            }
            return result;
        }
    }

    /// <summary>
    /// Dense jagged-array math helpers shared by the Transformer components.
    /// </summary>
    internal static class TransformerMath
    {
        #region Utility Methods
        /// <summary>
        /// Multiplies a [n, m] matrix by a [m, p] matrix.
        /// </summary>
        /// <exception cref="ArgumentException">Thrown if a row of <paramref name="a"/> does not have b.Length columns.</exception>
        internal static double[][] MatMul(double[][] a, double[][] b)
        {
            int inner = b.Length;
            int cols = inner == 0 ? 0 : b[0].Length;
            double[][] result = new double[a.Length][];
            for (int i = 0; i < a.Length; i++)
            {
                double[] row = a[i];
                // A mismatched inner dimension would otherwise silently drop terms or index out of range.
                if (row.Length != inner)
                    throw new ArgumentException($"Inner dimensions differ: {row.Length} vs {inner}.");
                double[] outRow = new double[cols];
                for (int k = 0; k < inner; k++)
                {
                    double aik = row[k];
                    double[] bk = b[k];
                    for (int j = 0; j < cols; j++)
                        outRow[j] += aik * bk[j];
                }
                result[i] = outRow;
            }
            return result;
        }

        /// <summary>
        /// Transposes a rectangular matrix; an empty matrix transposes to an empty matrix.
        /// </summary>
        internal static double[][] Transpose(double[][] m)
        {
            if (m.Length == 0)
                return new double[0][];
            return Enumerable.Range(0, m[0].Length).Select(j => m.Select(row => row[j]).ToArray()).ToArray();
        }

        /// <summary>
        /// Row-wise softmax; subtracts each row's max before exponentiating for numerical stability.
        /// </summary>
        internal static double[][] Softmax(double[][] x)
        {
            double[][] result = new double[x.Length][];
            for (int i = 0; i < x.Length; i++)
            {
                double max = x[i].Max();
                double[] exps = x[i].Select(v => Math.Exp(v - max)).ToArray();
                double sum = exps.Sum();
                result[i] = exps.Select(e => e / sum).ToArray();
            }
            return result;
        }

        internal static double[][] ReLU(double[][] x) =>
            x.Select(row => row.Select(v => Math.Max(0.0, v)).ToArray()).ToArray();

        internal static double[] Add(double[] a, double[] b) =>
            a.Select((v, i) => v + b[i]).ToArray();

        internal static double[][] RandomMatrix(int rows, int cols, Random rand) =>
            Enumerable.Range(0, rows).Select(_ => RandomVector(cols, rand)).ToArray();

        /// <summary>
        /// Returns <paramref name="size"/> values drawn uniformly from [-0.05, 0.05).
        /// </summary>
        internal static double[] RandomVector(int size, Random rand) =>
            Enumerable.Range(0, size).Select(_ => rand.NextDouble() * 0.1 - 0.05).ToArray();
        #endregion
    }
}
Much more can be done to make this better, but for the sake of simplicity this is a solid starting point.
If one wanted to train the model, a basic trainer could look like this:
using System;
using System.Linq;
namespace TransformerModel
{
    /// <summary>
    /// Handles the training of a <see cref="Transformer"/> model using gradient descent.
    /// </summary>
    /// <remarks>
    /// This class implements a basic training pipeline with mean squared error (MSE) loss and gradient descent optimization.
    /// It is designed for small-scale models and assumes the Transformer is used for regression-like tasks (e.g., predicting
    /// a sequence of vectors). For classification or language modeling, you'd need to adjust the loss function.
    /// NOTE: the component backward passes declared at the bottom of this namespace are identity placeholders, so the
    /// "gradients" they produce have the shape of the layer activations, not of the weight matrices. Weight updates
    /// verify shapes and throw <see cref="InvalidOperationException"/> on a mismatch rather than index out of range
    /// or silently apply values to the wrong weights; real backward passes are required for actual training.
    /// </remarks>
    public class TransformerTrainer
    {
        private readonly Transformer _transformer;
        private readonly double _learningRate;
        private readonly int _epochs;
        // Gradients from the most recent Backpropagate call: [0] attention, [1] feed-forward.
        // Kept on the trainer so the model itself carries no training scratch state.
        private double[][][] _gradients;

        /// <summary>
        /// Initializes a new instance of the <see cref="TransformerTrainer"/> class.
        /// </summary>
        /// <param name="transformer">The Transformer model to train.</param>
        /// <param name="learningRate">The learning rate for gradient descent (default: 0.001).</param>
        /// <param name="epochs">The number of training epochs (default: 100).</param>
        /// <exception cref="ArgumentNullException">Thrown if <paramref name="transformer"/> is null.</exception>
        /// <exception cref="ArgumentException">Thrown if <paramref name="learningRate"/> or <paramref name="epochs"/> are invalid.</exception>
        public TransformerTrainer(Transformer transformer, double learningRate = 0.001, int epochs = 100)
        {
            if (transformer == null)
                throw new ArgumentNullException(nameof(transformer), "Transformer cannot be null.");
            if (learningRate <= 0)
                throw new ArgumentException("Learning rate must be positive.", nameof(learningRate));
            if (epochs <= 0)
                throw new ArgumentException("Number of epochs must be positive.", nameof(epochs));
            _transformer = transformer;
            _learningRate = learningRate;
            _epochs = epochs;
        }

        /// <summary>
        /// Trains the Transformer model on the given input and target data, printing the mean loss per epoch.
        /// </summary>
        /// <param name="inputs">Array of input sequences [batchSize, seqLength, dModel].</param>
        /// <param name="targets">Array of target sequences [batchSize, seqLength, dModel].</param>
        /// <exception cref="ArgumentNullException">Thrown if <paramref name="inputs"/> or <paramref name="targets"/> are null.</exception>
        /// <exception cref="ArgumentException">Thrown if input and target dimensions don't match.</exception>
        /// <exception cref="InvalidOperationException">Thrown if a computed gradient does not match its weight matrix shape.</exception>
        public void Train(double[][][] inputs, double[][][] targets)
        {
            if (inputs == null || targets == null)
                throw new ArgumentNullException(inputs == null ? nameof(inputs) : nameof(targets), "Inputs and targets cannot be null.");
            // Compare every sequence and every row, not just row 0 (which also crashed on empty sequences).
            if (inputs.Length != targets.Length || !inputs.Zip(targets, (input, target) => SequencesMatch(input, target)).All(match => match))
                throw new ArgumentException("Inputs and targets must have matching dimensions [batchSize, seqLength, dModel].");
            for (int epoch = 0; epoch < _epochs; epoch++)
            {
                double totalLoss = 0;
                for (int i = 0; i < inputs.Length; i++)
                {
                    // Forward pass
                    double[][] output = _transformer.Forward(inputs[i]);
                    totalLoss += ComputeLoss(output, targets[i]);
                    // Backward pass (compute gradients)
                    double[][] outputGrad = ComputeOutputGradient(output, targets[i]);
                    Backpropagate(inputs[i], outputGrad);
                    // Update weights
                    UpdateWeights();
                }
                Console.WriteLine($"Epoch {epoch + 1}/{_epochs}, Loss: {totalLoss / inputs.Length:F6}");
            }
        }

        /// <summary>
        /// True when both sequences are non-null and have the same number of rows with pairwise equal row lengths.
        /// </summary>
        private static bool SequencesMatch(double[][] input, double[][] target)
        {
            if (input == null || target == null || input.Length != target.Length)
                return false;
            for (int i = 0; i < input.Length; i++)
            {
                if (input[i] == null || target[i] == null || input[i].Length != target[i].Length)
                    return false;
            }
            return true;
        }

        /// <summary>
        /// Computes the mean squared error loss between output and target.
        /// </summary>
        /// <param name="output">The model's predicted output.</param>
        /// <param name="target">The target output.</param>
        /// <returns>The MSE loss value.</returns>
        private double ComputeLoss(double[][] output, double[][] target)
        {
            double loss = 0;
            for (int i = 0; i < output.Length; i++)
                for (int j = 0; j < output[i].Length; j++)
                    loss += Math.Pow(output[i][j] - target[i][j], 2);
            return loss / (output.Length * output[0].Length);
        }

        /// <summary>
        /// Computes the gradient of the MSE loss with respect to the output: 2 * (output - target) / elementCount.
        /// </summary>
        private double[][] ComputeOutputGradient(double[][] output, double[][] target)
        {
            int count = output.Length * output[0].Length;
            double[][] grad = new double[output.Length][];
            for (int i = 0; i < output.Length; i++)
            {
                grad[i] = new double[output[i].Length];
                for (int j = 0; j < output[i].Length; j++)
                    grad[i][j] = 2 * (output[i][j] - target[i][j]) / count;
            }
            return grad;
        }

        /// <summary>
        /// Propagates <paramref name="outputGrad"/> back through the layers and stores the results for
        /// <see cref="UpdateWeights"/>.
        /// </summary>
        /// <remarks>
        /// The forward pass is not repeated here; <see cref="Train"/> already ran it. A real implementation would need
        /// the intermediate activations cached during the forward pass. The component backward methods are placeholders.
        /// </remarks>
        private void Backpropagate(double[][] input, double[][] outputGrad)
        {
            // Backprop through second normalization and FFN
            double[][] norm2Grad = _transformer.Norm2.NormalizeGradient(outputGrad);
            double[][] ffnGrad = _transformer.FFN.Backward(norm2Grad);
            double[][] norm1InputGrad = AddResidualGradient(ffnGrad, norm2Grad);
            // Backprop through first normalization and attention
            double[][] norm1Grad = _transformer.Norm1.NormalizeGradient(norm1InputGrad);
            double[][] attnGrad = _transformer.Attention.Backward(norm1Grad, input);
            // Simplified: one gradient per sample, not accumulated over a batch.
            _gradients = new[] { attnGrad, ffnGrad };
        }

        /// <summary>
        /// Updates the model weights using the gradients stored by the last <see cref="Backpropagate"/> call.
        /// </summary>
        /// <exception cref="InvalidOperationException">Thrown if no gradients are available or their shapes mismatch.</exception>
        private void UpdateWeights()
        {
            if (_gradients == null)
                throw new InvalidOperationException("Backpropagate must run before UpdateWeights.");
            // Update Attention weights (simplified)
            UpdateAttentionWeights(_transformer.Attention, _gradients[0]);
            // Update FFN weights
            UpdateFFNWeights(_transformer.FFN, _gradients[1]);
            // In a real model, update normalization parameters too (gamma, beta)
        }

        #region Weight Update Helpers
        /// <summary>
        /// Applies <paramref name="grad"/> to each of the attention projection matrices (simplified: same gradient for all).
        /// </summary>
        private void UpdateAttentionWeights(MultiHeadAttention attention, double[][] grad)
        {
            // Validate every matrix first so a mismatch never leaves the layer partially updated.
            foreach (double[][] weights in attention.Weights)
                EnsureSameShape(weights, grad, "attention");
            foreach (double[][] weights in attention.Weights)
                ApplyGradient(weights, grad);
        }

        /// <summary>
        /// Applies <paramref name="grad"/> to both feed-forward weight matrices (simplified: same gradient for both).
        /// </summary>
        private void UpdateFFNWeights(FeedForwardNetwork ffn, double[][] grad)
        {
            EnsureSameShape(ffn.W1, grad, "feed-forward W1");
            EnsureSameShape(ffn.W2, grad, "feed-forward W2");
            ApplyGradient(ffn.W1, grad);
            ApplyGradient(ffn.W2, grad);
        }

        /// <summary>
        /// Gradient-descent step in place: weights -= learningRate * grad. Shapes must already be validated.
        /// </summary>
        private void ApplyGradient(double[][] weights, double[][] grad)
        {
            for (int i = 0; i < weights.Length; i++)
                for (int j = 0; j < weights[i].Length; j++)
                    weights[i][j] -= _learningRate * grad[i][j];
        }

        /// <summary>
        /// Throws if <paramref name="grad"/> does not have exactly the shape of <paramref name="weights"/>.
        /// </summary>
        private static void EnsureSameShape(double[][] weights, double[][] grad, string name)
        {
            bool matches = grad != null && grad.Length == weights.Length;
            for (int i = 0; matches && i < weights.Length; i++)
                matches = grad[i] != null && grad[i].Length == weights[i].Length;
            if (!matches)
                throw new InvalidOperationException(
                    $"Gradient shape does not match the {name} weights; the placeholder backward passes do not produce weight gradients.");
        }
        #endregion

        #region Utility Methods
        private double[][] AddResidualGradient(double[][] x, double[][] residual) =>
            x.Select((row, i) => row.Select((v, j) => v + residual[i][j]).ToArray()).ToArray();
        #endregion
    }

    // Training-only members. These are partial declarations of the model classes so they can read the
    // classes' private weight fields; the model-side declarations must also be marked partial.
    internal partial class MultiHeadAttention
    {
        /// <summary>Exposes the Q, K, V projection weights for in-place updates.</summary>
        public double[][][] Weights => _weights;

        /// <summary>Placeholder backward pass: returns <paramref name="grad"/> unchanged (not a real gradient).</summary>
        public double[][] Backward(double[][] grad, double[][] input) => grad;
    }

    internal partial class FeedForwardNetwork
    {
        /// <summary>Exposes the first-layer weights for in-place updates.</summary>
        public double[][] W1 => _w1;

        /// <summary>Exposes the second-layer weights for in-place updates.</summary>
        public double[][] W2 => _w2;

        /// <summary>Placeholder backward pass: returns <paramref name="grad"/> unchanged (not a real gradient).</summary>
        public double[][] Backward(double[][] grad) => grad;
    }

    internal partial class LayerNormalization
    {
        /// <summary>Placeholder backward pass: returns <paramref name="grad"/> unchanged (not a real gradient).</summary>
        public double[][] NormalizeGradient(double[][] grad) => grad;
    }
}
Again, one could use the Matrix class here instead, but some may prefer plain arrays because the operations are easier to see.