sanskxr02 committed dfdc6d9 (verified, parent 7e4eab8): Create README.md (+68 lines)
---
license: mit
language:
- en
tags:
- attention
- temporal-reasoning
- time-series
- inductive-bias
- plug-and-play
---
# TemporalSelfAttention - A Time-Biased Attention Module

> Give Transformers a sense of time - not by scaling, but by structure.

---

## Why?

Standard attention treats all tokens as equally distant in time: temporal order is visible only through position embeddings. That works for syntax, but breaks down for:

- Temporal event ordering
- Causal reasoning
- Timeline consistency
- Long-range narrative coherence

💡 Insight: standard Transformers *simulate* time via token position. We inject it *structurally* with a tiny inductive bias.

---
## Core Equation

The time-aware attention score is computed as:

$$
\text{score}_{ij} = \frac{Q_i \cdot K_j^\top}{\sqrt{d_k}} + \gamma \cdot f(t_j - t_i)
$$

### Notation

| Symbol | Description |
|-----------------|-------------|
| \\( \text{score}_{ij} \\) | Attention score between query at position \\( i \\) and key at position \\( j \\) |
| \\( Q_i \\) | Query vector for position \\( i \\) |
| \\( K_j \\) | Key vector for position \\( j \\) |
| \\( d_k \\) | Dimension of the key vectors |
| \\( \gamma \\) | Learnable time-bias strength |
| \\( f(\cdot) \\) | Time-difference function (e.g. linear or Gaussian) |
| \\( t_j - t_i \\) | Relative time difference between positions \\( j \\) and \\( i \\) |

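To make the equation concrete, here is a small NumPy sketch (an illustration, not the package's implementation) that computes the biased scores with \\( f(\Delta t) = -|\Delta t| \\), one plausible linear choice for \\( f \\):

```python
import numpy as np

def time_biased_scores(Q, K, t, gamma=1.0):
    """score_ij = Q_i . K_j / sqrt(d_k) + gamma * f(t_j - t_i),
    with f(dt) = -|dt| as an illustrative linear recency bias."""
    d_k = Q.shape[-1]
    content = Q @ K.T / np.sqrt(d_k)        # standard scaled dot-product term
    dt = t[None, :] - t[:, None]            # dt[i, j] = t_j - t_i
    return content + gamma * (-np.abs(dt))  # add the structural time bias

rng = np.random.default_rng(0)
T, d = 5, 8
Q, K = rng.normal(size=(T, d)), rng.normal(size=(T, d))
t = np.arange(T, dtype=float)               # evenly spaced timestamps

scores = time_biased_scores(Q, K, t)
weights = np.exp(scores - scores.max(-1, keepdims=True))
weights /= weights.sum(-1, keepdims=True)   # row-wise softmax sums to 1
```

With \\( \gamma > 0 \\) this choice of \\( f \\) penalizes temporally distant pairs, so attention concentrates on nearby events unless the content term overrides it.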
## How To Use

```python
import torch

from temporal_attention import TemporalSelfAttention

model = TemporalSelfAttention(
    embed_dim=64,
    num_heads=1,
    bias_type="linear",  # or "gaussian"
    gamma=1.0,
    causal=False,
)

# x: (B, T, D) token embeddings, timestamps: (B, T) per-token times
x = torch.randn(2, 16, 64)
timestamps = torch.arange(16, dtype=torch.float32).expand(2, 16)

output, weights = model(x, timestamps)
```
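For intuition about what the forward pass does, here is a rough single-head sketch of the computation in plain NumPy. The class name, weight initialization, and `f(dt) = -|dt|` bias are illustrative assumptions, not the actual `temporal_attention` source:

```python
import numpy as np

class TemporalSelfAttentionSketch:
    """Hypothetical single-head sketch of the time-biased forward pass."""

    def __init__(self, embed_dim, gamma=1.0, causal=False, seed=0):
        rng = np.random.default_rng(seed)
        scale = 1.0 / np.sqrt(embed_dim)
        self.Wq = rng.normal(scale=scale, size=(embed_dim, embed_dim))
        self.Wk = rng.normal(scale=scale, size=(embed_dim, embed_dim))
        self.Wv = rng.normal(scale=scale, size=(embed_dim, embed_dim))
        self.gamma, self.causal = gamma, causal

    def __call__(self, x, timestamps):
        # x: (B, T, D), timestamps: (B, T)
        Q, K, V = x @ self.Wq, x @ self.Wk, x @ self.Wv
        d_k = Q.shape[-1]
        scores = Q @ K.transpose(0, 2, 1) / np.sqrt(d_k)       # (B, T, T)
        dt = timestamps[:, None, :] - timestamps[:, :, None]   # dt[b, i, j] = t_j - t_i
        scores = scores + self.gamma * (-np.abs(dt))           # linear time bias
        if self.causal:
            scores = np.where(dt > 0, -np.inf, scores)         # mask future tokens
        w = np.exp(scores - scores.max(-1, keepdims=True))
        w /= w.sum(-1, keepdims=True)                          # softmax over keys
        return w @ V, w                                        # (B, T, D), (B, T, T)

x = np.random.default_rng(1).normal(size=(2, 10, 64))
ts = np.tile(np.arange(10.0), (2, 1))
out, weights = TemporalSelfAttentionSketch(64)(x, ts)
```

Note how the bias is added *before* the softmax, so it reshapes the attention distribution rather than rescaling its output.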