Language Overview

TensaLang is a programmable LLM runtime: a language, compiler, and runtime system that lets you write model forward passes and sampling logic in source code, compile them through MLIR to CPU or CUDA targets, and execute them directly.

Rather than embedding LLM logic in a library, TensaLang makes the runtime the program itself. You can change attention mechanisms, sampling strategies, tiling, and memory placement without rewriting the compiler; it is all source code.

Key Design Principles

  • Tensors are first-class. Declare shapes, specify precisions, and let the compiler handle memory layouts.
  • Scheduling is source-level. Tile sizes, parallel indices, and memory hints are expressed in with clauses.
  • IR preserves structure. MLIR keeps tensor ops, loop nests, and parallel dimensions intact, not just thread layouts for a specific GPU.
  • Targets are interchangeable. Write once, compile to CUDA, CPU-SIMD, or (planned) MLX and ROCm.

Compilation Pipeline

TensaLang programs flow through several stages from source to executable:

.tl source
    ↓
[tensalang_sugar.py] ← Python frontend tokenizes and parses .tl
    ↓
S-expression IR (.tl.sx)
    ↓
[C++ compiler (bin/tensalang)]
    ├─ Parse S-expression
    ├─ Build MLIR (memref/scf/arith/math/func dialects)
    ├─ Run optimization passes (CSE, canonicalization, inlining)
    ├─ Target-specific lowering:
    │  ├─ CPU: MLIR → LLVM IR → JIT compile
    │  └─ CUDA: GPU dialect → NVVM → cubin serialization
    └─ Execution engine
         ↓
    Token generation output

Key IR Stages

Stage | Description
S-expression IR | Compact, easy-to-parse intermediate format produced by the Python frontend
MLIR | Full compiler IR with dialects (memref, scf, arith, math, gpu)
LLVM/NVVM | Target-specific code generation

Building from Source

Dependencies

  • LLVM 18, with llvm-config and clang++
  • A C++17-compatible compiler
  • Python 3.x, for the frontend (tensalang_sugar.py)
  • CUDA Toolkit, for GPU support (include/cuda.h, libcuda.so, libcublas.so)

Build Commands

# Clone the repository
git clone https://github.com/BenChaliah/Tensa-Lang.git
cd Tensa-Lang

# Build everything
./build.sh

# Build outputs:
# - bin/tensalang       (main compiler, ~3.6 MB)
# - bin/tensalang-run   (CLI runner)
# - bin/libtensalang_runtime.a (static runtime library)

Running Examples

Download example weights from HuggingFace first:

git clone https://huggingface.co/DatarusAI/Tensa-Lang models
# Run Llama2 fp16 on CUDA
./bin/tensalang-run examples/llama2_manual_tiling_fp16.tl \
  --model models/llama2_7b/llama2_7b_f16.safetensors \
  --tokenizer models/llama2_7b/tokenizer.json \
  --prompt "Once upon a time" \
  --target cuda \
  --cuda-arch sm_89

# Run Qwen2.5-Coder on CUDA
./bin/tensalang-run examples/qwen25_coder_bf16.tl \
  --model models/Qwen2.5-0.5B-Coder-Instruct/qwen25_0.5b_bf16.safetensors \
  --tokenizer models/Qwen2.5-0.5B-Coder-Instruct/tokenizer.json \
  --prompt "Once upon a time" \
  --target cuda \
  --cuda-arch sm_89

# Run on CPU (auto-tiling)
./bin/tensalang-run examples/llama2_auto_tiling_fp16.tl \
  --model models/llama2_7b/llama2_7b_f16.safetensors \
  --tokenizer models/llama2_7b/tokenizer.json \
  --prompt "Hello" \
  --target cpu \
  --cpu-threads 8

Docker Quick Test

Build the Docker image and run with the helper script, which auto-mounts model and tokenizer paths:

# Build the image
docker build -f docker/Dockerfile -t tensalang:local .

# Run with the helper (auto-mounts model/tokenizer paths)
./docker_command_exec examples/llama2_manual_tiling_fp16.tl \
  --model /path/to/llama2_7b_f16.safetensors \
  --tokenizer /path/to/tokenizer.json \
  --prompt "Once upon a time" \
  --target cuda \
  --cuda-arch sm_89

Set TENSALANG_DOCKER_IMAGE to override the image name. Use TENSALANG_DOCKER_FORCE_GPU=1 or TENSALANG_DOCKER_NO_GPU=1 to force GPU on/off.

Syntax Basics

Lexical Rules

  • Comments: # to end of line
  • Statements: Terminated by newline or ;
  • Blocks: { ... }
  • Identifiers: A-Za-z_ followed by A-Za-z0-9_
  • Numeric literals: Treated as f32 unless cast

Function Definition

fn add(a: f32, b: f32) -> f32 {
  return a + b
}

# With scheduling hints
fn matmul(w: Tensor<f32, [O, I]>, x: Tensor<f32, [I]>) -> Tensor<f32, [O]>
    with parallel=[o], tile=[64] {
  var y: Tensor<f32, [O]>
  y[o] = sum(i) w[o, i] * x[i]
  return y
}

Control Flow

# For loop (range)
for i in 0..N {
  x[i] = i as f32
}

# While loop
while i < n {
  i = i + 1
}

# If expression
x = if cond { a } else { b }

# Break/continue supported
for i in 0..N {
  if i == 5 { break }
}

Type System

Scalar Types

Category | Types
Signed integers | i8, i16, i32, i64
Unsigned integers | u8, u16, u32, u64
Floats | f16, bf16, f32, f64
Other | bool, string
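f16 and bf16 split their 16 bits differently: f16 keeps 10 mantissa bits with a 5-bit exponent, while bf16 keeps f32's full 8-bit exponent with only 7 mantissa bits, trading precision for f32-like dynamic range. A stdlib Python sketch of the conversions (an illustration of the formats, not TensaLang code):

```python
import struct

def f32_bits(x: float) -> int:
    # Bit pattern of x as IEEE-754 binary32
    return struct.unpack("<I", struct.pack("<f", x))[0]

def to_bf16_bits(x: float) -> int:
    # bf16 keeps the f32 sign/exponent and the top 7 mantissa bits;
    # simple truncation shown here (real converters usually round to nearest even)
    return f32_bits(x) >> 16

def bf16_to_f32(bits: int) -> float:
    # Re-expand by padding the low 16 bits with zeros
    return struct.unpack("<f", struct.pack("<I", bits << 16))[0]

def to_f16_and_back(x: float) -> float:
    # f16 round-trip via the struct 'e' format (IEEE binary16)
    return struct.unpack("<e", struct.pack("<e", x))[0]
```

Because bf16 shares f32's exponent, very large activations that would overflow f16 (anything past 65504) survive a bf16 round-trip.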

Tensor Types

# Static dimensions
Tensor<f32, [128, 64]>

# Symbolic dimensions (bound at runtime)
Tensor<f16, [B, L, D]>

# Unknown dimension
Tensor<f32, [_, D]>

# Arrays (rank-1 memref)
Array<i32, [N]>

Structs

struct KVCache {
  key: Tensor<f16, [L, SeqLen, D]>
  value: Tensor<f16, [L, SeqLen, D]>
}

Tensor Operations

# Allocation
var y: Tensor<f32, [N]> = zeros([N])
var x: Tensor<f16, [N]> = zeros_f16([N])

# Indexing (creates implicit loops)
y[i] = x[i] + 1.0
M[i, j] = M[i, j] * 2

# Reductions
y[o] = sum(i) w[o, i] * x[i]   # Sum reduction
m[h] = max(t) att[h, t]        # Max reduction

# Dimension access
const n = dim(x, 0)   # Returns f32, cast for loop bounds
const len = len(x)     # Shorthand for dim(x, 0)
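The reduction form above is an implicit loop nest: the free index o becomes an outer loop and the sum over i an inner accumulation, i.e. a matrix-vector product. The same semantics in plain Python (illustration only, not TensaLang):

```python
def matvec(w, x):
    # w: list of rows (O x I), x: vector of length I
    # Equivalent of the TensaLang reduction: y[o] = sum(i) w[o, i] * x[i]
    return [sum(w_oi * x_i for w_oi, x_i in zip(row, x)) for row in w]

w = [[1.0, 2.0],
     [3.0, 4.0]]
x = [10.0, 100.0]
y = matvec(w, x)   # [1*10 + 2*100, 3*10 + 4*100] = [210.0, 430.0]
```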

Loading from Safetensors

# The compiler detects safetensors indexing and emits runtime loader
var w: Tensor<f16, [O, I]> = weights["layer.0.weight"]

# Dynamic key construction
const wq: Tensor<f16, [D, D]> = st[layer_key("layers.", l, ".wq")]
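Under the hood, the safetensors format is simple: an 8-byte little-endian header length, a JSON header mapping tensor names to dtype, shape, and byte offsets, then raw tensor data. A Python sketch of parsing that header (illustrative; TensaLang's actual loader is implemented in the C++ runtime):

```python
import json
import struct

def read_safetensors_header(buf: bytes) -> dict:
    # First 8 bytes: little-endian u64 length of the JSON header
    (n,) = struct.unpack("<Q", buf[:8])
    # Each entry: {"dtype": "F16", "shape": [...], "data_offsets": [begin, end]}
    return json.loads(buf[8:8 + n])

# Build a tiny in-memory "file" to demonstrate (hypothetical tensor name)
meta = {"layer.0.weight": {"dtype": "F16", "shape": [2, 2],
                           "data_offsets": [0, 8]}}
hdr = json.dumps(meta).encode()
blob = struct.pack("<Q", len(hdr)) + hdr + b"\x00" * 8
```

The data_offsets are relative to the byte immediately after the header, so the loader can mmap the file and hand out tensor views without copying.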

Scheduling Hints

Scheduling hints are specified with the with clause after the function signature. These are preferences, not directives; the compiler may clamp or ignore them based on hardware constraints.

fn attention_f16(...) -> Tensor<f32, [D]>
    with tile=[8, 64], parallel=[h, t],
         memory={key_cache: shared_mem, value_cache: shared_mem} {
  # ...
}

Available Hints

Hint | Description
tile=[x, y, ...] | Preferred CUDA block sizes for parallelized assignments
parallel=[i, j, ...] | Index variables to map to GPU threads/blocks
memory={...} | Memory placement hints (e.g. memory={key_cache: shared_mem, value_cache: shared_mem})

Builtins

Tensor Constructors

zeros([H, T])       # f32 tensor filled with zeros
ones([N])           # f32 tensor filled with ones
zeros_f16([H, T])   # f16 tensor filled with zeros
ones_f16([N])       # f16 tensor filled with ones

Math Operations

exp(x), log(x), tanh(x), sqrt(x)
abs(x), floor(x), ceil(x)
sin(x), cos(x), pow(base, exp)

Overridable Builtins

If you don't define your own, the compiler provides default implementations for:

  • softmax: normalizes the last axis
  • layernorm: RMSNorm-like normalization

Define a function with the same name in your .tl file to override.
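What the two defaults compute can be sketched in plain Python (an illustration of the semantics, not the compiler's generated code):

```python
import math

def softmax(xs):
    # Subtract the max for numerical stability, then normalize to sum to 1
    m = max(xs)
    exps = [math.exp(v - m) for v in xs]
    s = sum(exps)
    return [e / s for e in exps]

def rmsnorm(xs, weight, eps=1e-5):
    # RMSNorm-like: scale by 1/sqrt(mean(x^2) + eps), then by a learned weight
    inv = 1.0 / math.sqrt(sum(v * v for v in xs) / len(xs) + eps)
    return [w * (v * inv) for v, w in zip(xs, weight)]
```

Overriding is useful when a model needs a variant (e.g. classic LayerNorm with mean subtraction instead of the RMSNorm-like default).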

Runtime Hooks

Extern functions bind TensaLang to the C++ runtime for I/O, memory management, and tokenization.

# Safetensors loading
extern fn safetensors_open(path: string) -> i64 = tensalang_safetensors_open
extern fn safetensors_close(handle: i64) = tensalang_safetensors_close

# Tokenization
extern fn tokenize(text: string) -> Tensor<i32, [P]> = tensalang_tokenize
extern fn decode(token: i32) -> string = tensalang_decode

# Arena allocator (reuses memory each token)
extern fn arena_begin() = tensalang_arena_begin
extern fn arena_reset() = tensalang_arena_reset
extern fn arena_end() = tensalang_arena_end
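The arena hooks implement a bump allocator: everything allocated during one token's forward pass is released in O(1) by a reset, so the next token reuses the same memory. A minimal Python sketch of the pattern (the real allocator is C++; integer offsets stand in for pointers):

```python
class Arena:
    """Bump allocator: allocations are freed all at once by reset()."""
    def __init__(self, capacity: int):
        self.capacity = capacity
        self.offset = 0

    def alloc(self, nbytes: int) -> int:
        # Return the offset of the new block and bump the cursor
        if self.offset + nbytes > self.capacity:
            raise MemoryError("arena exhausted")
        start = self.offset
        self.offset += nbytes
        return start

    def reset(self):
        # Equivalent of arena_reset(): drop everything in O(1)
        self.offset = 0

arena = Arena(1024)
a = arena.alloc(256)   # scratch for one token step
b = arena.alloc(256)
arena.reset()          # next token reuses the same memory
c = arena.alloc(256)
```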

CUDA Target

The CUDA backend compiles .tl code to GPU kernels through MLIR's GPU dialect and NVVM.

Key Optimizations

  • Fused Attention: Controlled by TENSALANG_FUSED_ATTENTION (0=disabled, 1=single-kernel, 2=two-stage default)
  • cuBLAS GEMV: Pattern-matched matvec dispatches to cuBLAS for compatible layouts
  • Managed Memory: CUDA unified memory for host-GPU transfers

Environment Variables

TENSALANG_CUDA_ARCH=sm_89        # GPU compute capability
TENSALANG_FUSED_ATTENTION=2      # Attention fusion mode
TENSALANG_CUBLAS_MATVEC=1        # cuBLAS dispatch
TENSALANG_CUDA_SINGLE_STREAM=1   # Single stream mode

CPU-SIMD Target

The CPU backend lowers MLIR to LLVM IR with vectorization passes for SIMD execution.

Features

  • AVX2/AVX-512 vectorization for math operations
  • Parallel loop execution with OpenMP-style threading
  • Optimized SIMD kernels for attention and RMSNorm

MLX / ROCm (Planned)

Future targets are planned for the MLIR backend abstraction:

  • MLX: Metal Performance Shaders and custom MSL kernels for Apple Silicon
  • ROCm: AMD GPUs via HIP/AMDGPU using MLIR's ROCDL dialect

Llama2 Implementation

The repository includes complete Llama2 implementations in FP16 and FP32. Key components:

  • RMSNorm with FP16 weights
  • RoPE positional embeddings
  • Flash Attention with KV cache
  • Top-p sampling with temperature
# Core functions from llama2_example.tl

fn matmul_vec_f16(w: Tensor<f16, [O, I]>, x: Tensor<f32, [I]>) -> Tensor<f32, [O]>
    with parallel=[o] {
  var y: Tensor<f32, [O]>
  y[o] = sum(i) (w[o, i] as f32) * x[i]
  return y
}

fn rmsnorm(x: Tensor<f32, [D]>, weight: Tensor<f16, [D]>, eps: f32) -> Tensor<f32, [D]> {
  const n = dim(x, 0)
  var ss: Tensor<f32, [1]>
  ss[0] = sum(i) x[i] * x[i]
  const inv: f32 = 1.0 / sqrt(ss[0] / n + eps)

  var y: Tensor<f32, [D]>
  y[i] = (weight[i] as f32) * (x[i] * inv)
  return y
}

fn silu_mul(a: Tensor<f32, [N]>, b: Tensor<f32, [N]>) -> Tensor<f32, [N]> {
  var out: Tensor<f32, [N]>
  out[i] = a[i] * (1.0 / (1.0 + exp(-a[i]))) * b[i]
  return out
}
# Attention with KV cache and scheduling hints
fn attention_f16(q: Tensor<f32, [D]>,
                 key_cache: Tensor<f16, [L, SeqLen, KvDim]>,
                 value_cache: Tensor<f16, [L, SeqLen, KvDim]>,
                 layer: i32, pos: i32, H: i32, KvH: i32, scale: f32) -> Tensor<f32, [D]>
    with tile=[8, 64], parallel=[h, t],
         memory={key_cache: shared_mem, value_cache: shared_mem} {
  const D: i32 = dim(q, 0) as i32
  const Dh: i32 = D / H
  const kv_mul: i32 = H / KvH

  var att: Tensor<f32, [H, SeqLen]> = zeros([H, SeqLen])
  var xb_att: Tensor<f32, [D]> = zeros([D])

  for h in 0..H {
    # Compute attention scores
    for t in 0..AttLen {
      var score: Tensor<f32, [1]>
      score[0] = sum(d) q[h * Dh + d] *
                 (key_cache[layer, t, (h / kv_mul) * Dh + d] as f32) * scale
      att[h, t] = score[0]
    }

    # Softmax + weighted sum over values
    var weights: Tensor<f32, [SeqLen]> = softmax(att[h, :])
    for i in 0..Dh {
      xb_att[h * Dh + i] = sum(t) weights[t] *
        (value_cache[layer, t, (h / kv_mul) * Dh + i] as f32)
    }
  }
  return xb_att
}
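The index arithmetic h / kv_mul above is grouped-query attention: with kv_mul = H / KvH, each consecutive group of kv_mul query heads reads the same KV head. A tiny Python illustration of the mapping (not TensaLang code):

```python
def kv_head_for(h: int, H: int, KvH: int) -> int:
    # Grouped-query attention: query head h reads KV head h // (H // KvH)
    kv_mul = H // KvH
    return h // kv_mul

# With 14 query heads and 2 KV heads (the Qwen2.5-0.5B configuration below),
# heads 0..6 share KV head 0 and heads 7..13 share KV head 1.
mapping = [kv_head_for(h, 14, 2) for h in range(14)]
```

Sharing KV heads this way shrinks the KV cache by a factor of kv_mul, which is why key_cache and value_cache are indexed with KvDim rather than D.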

# Full transformer forward pass
fn forward(st: i64, token: i32, pos: i32, H: i32, KvH: i32,
           key_cache: ref Tensor<f16, [L, SeqLen, KvDim]>,
           value_cache: ref Tensor<f16, [L, SeqLen, KvDim]>) -> Tensor<f32, [Vocab]> {
  const wte: Tensor<f16, [Vocab, D]> = st["token_embedding_table"]
  const wcls: Tensor<f16, [Vocab, D]> = st["wcls"]
  const rms_final: Tensor<f16, [D]> = st["rms_final_weight"]
  const scale: f32 = 1.0 / sqrt(Dh as f32)

  var x: Tensor<f32, [D]> = zeros([D])
  x[d] = wte[token, d] as f32

  for l in 0..L {
    const wq: Tensor<f16, [D, D]> = st[layer_key("layers.", l, ".wq")]
    const wk: Tensor<f16, [KvDim, D]> = st[layer_key("layers.", l, ".wk")]
    const wv: Tensor<f16, [KvDim, D]> = st[layer_key("layers.", l, ".wv")]
    const wo: Tensor<f16, [D, D]> = st[layer_key("layers.", l, ".wo")]
    const w1: Tensor<f16, [Hidden, D]> = st[layer_key("layers.", l, ".w1")]
    const w2: Tensor<f16, [D, Hidden]> = st[layer_key("layers.", l, ".w2")]
    const w3: Tensor<f16, [Hidden, D]> = st[layer_key("layers.", l, ".w3")]
    const rms_att: Tensor<f16, [D]> = st[layer_key("layers.", l, ".rms_att_weight")]
    const rms_ffn: Tensor<f16, [D]> = st[layer_key("layers.", l, ".rms_ffn_weight")]

    # Attention block
    var xb: Tensor<f32, [D]> = rmsnorm(x, rms_att, 1e-5)
    var q: Tensor<f32, [D]> = matmul_vec_f16(wq, xb)
    var k: Tensor<f32, [KvDim]> = matmul_vec_f16(wk, xb)
    var v: Tensor<f32, [KvDim]> = matmul_vec_f16(wv, xb)

    rope_kv_f16(q, k, v, key_cache, value_cache, l, pos, H, KvH)
    var xb_att: Tensor<f32, [D]> = attention_f16(q, key_cache, value_cache, l, pos, H, KvH, scale)
    var xb2: Tensor<f32, [D]> = matmul_vec_f16(wo, xb_att)
    add_inplace(x, xb2)

    # MLP block (SwiGLU)
    var xb_ffn: Tensor<f32, [D]> = rmsnorm(x, rms_ffn, 1e-5)
    var hb: Tensor<f32, [Hidden]> = matmul_vec_f16(w1, xb_ffn)
    var hb2: Tensor<f32, [Hidden]> = matmul_vec_f16(w3, xb_ffn)
    var h: Tensor<f32, [Hidden]> = silu_mul(hb, hb2)
    add_inplace(x, matmul_vec_f16(w2, h))
  }

  # Final normalization and output projection
  var xn: Tensor<f32, [D]> = rmsnorm(x, rms_final, 1e-5)
  var logits: Tensor<f32, [Vocab]> = matmul_vec_f16(wcls, xn)
  return logits
}
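The forward pass above is paired with top-p (nucleus) sampling under temperature, as listed in the key components. A generic Python sketch of that algorithm (illustrative, not the runtime's exact implementation; the helper name sample_top_p is made up here):

```python
import math
import random

def sample_top_p(logits, temperature=1.0, top_p=0.9, rng=random.random):
    # Temperature scaling + stable softmax
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    probs = [(i, e / total) for i, e in enumerate(exps)]
    # Keep the smallest set of tokens whose cumulative mass reaches top_p
    probs.sort(key=lambda p: p[1], reverse=True)
    kept, acc = [], 0.0
    for i, p in probs:
        kept.append((i, p))
        acc += p
        if acc >= top_p:
            break
    # Renormalize over the nucleus and draw one token
    r = rng() * acc
    for i, p in kept:
        r -= p
        if r <= 0:
            return i
    return kept[-1][0]
```

Lower temperature sharpens the distribution and lower top_p shrinks the nucleus; both push sampling toward the argmax token.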

Qwen2.5-Coder

The Qwen2.5-Coder-0.5B-Instruct implementation demonstrates BF16 inference with grouped query attention:

# Model configuration
const QWEN_H: i32 = 14                  # Attention heads
const QWEN_KVH: i32 = 2                 # KV heads (GQA)
const QWEN_L: i32 = 24                  # Layers
const QWEN_SEQ: i32 = 32768             # Max sequence length
const QWEN_ROPE_THETA: f32 = 1000000.0  # RoPE base
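QWEN_ROPE_THETA is the frequency base of the rotary embedding: each consecutive pair of dimensions is rotated by an angle proportional to the position, with per-pair frequencies theta^(-2i/d). A Python sketch of the standard formulation (illustrative; the actual rotation lives in the .tl RoPE kernels):

```python
import math

def rope_rotate(vec, pos, theta=1000000.0):
    # Rotate consecutive pairs (vec[i], vec[i+1]) by angle pos * theta^(-i/d)
    d = len(vec)
    out = list(vec)
    for i in range(0, d, 2):
        freq = theta ** (-i / d)
        angle = pos * freq
        c, s = math.cos(angle), math.sin(angle)
        x0, x1 = vec[i], vec[i + 1]
        out[i] = x0 * c - x1 * s
        out[i + 1] = x0 * s + x1 * c
    return out
```

Because rotation preserves lengths, dot products between rotated queries and keys depend only on relative position. A larger base (1e6 here vs Llama2's 1e4) slows the frequency decay, which helps long contexts like QWEN_SEQ = 32768.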

Compiler Architecture

Project Structure

src/
├── tensalang_sugar.py    # Python frontend (.tl → S-expr)
├── codegen.cpp           # MLIR code generation (6,140 lines)
├── runtime_core.cpp      # Safetensors, tokenizer, sampling
├── backend.cpp           # Backend abstraction layer
└── helpers.cpp           # S-expr parsing utilities

Targets/
├── CUDA/
│   ├── runtime_cuda.cpp  # CUDA runtime wrappers
│   └── gpu_serialize.cpp # GPU kernel → cubin
└── CPU/
    └── runtime_cpu.cpp   # CPU with SIMD optimizations

MLIR Dialects

TensaLang uses MLIR as its core IR, enabling powerful optimization passes and target-specific lowering.

Dialects Used

Dialect | Purpose
memref | Memory references and tensor storage
scf | Structured control flow (for, while, if)
arith | Arithmetic operations
math | Mathematical functions (exp, log, etc.)
gpu | GPU kernel launch and thread mapping
nvvm | NVIDIA-specific operations for CUDA
llvm | LLVM IR generation for the CPU target