

Function validate_multihead_shapes 

fn validate_multihead_shapes(
    q: &[f64],
    q_rows: usize,
    q_total_cols: usize,
    k: &[f64],
    k_rows: usize,
    k_total_cols: usize,
    v: &[f64],
    v_rows: usize,
    v_total_cols: usize,
    n_heads: usize,
) -> Result<(), String>
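The page documents only the signature, not which checks the function performs. The following is a minimal sketch of what a validator with this signature might check. It rests on four assumptions, none of which this page confirms:

- `q`, `k` and `v` are row-major matrices flattened into slices of `rows * total_cols` values.
- Each matrix's columns are split evenly across `n_heads`.
- Queries and keys share the same total width, so per-head dot products line up.
- `k` and `v` have the same number of rows.

The body is illustrative, not the crate's implementation.

```rust
/// Hypothetical sketch: every check below is an assumption about what
/// `validate_multihead_shapes` verifies, not its documented behaviour.
#[allow(clippy::too_many_arguments)]
fn validate_multihead_shapes(
    q: &[f64],
    q_rows: usize,
    q_total_cols: usize,
    k: &[f64],
    k_rows: usize,
    k_total_cols: usize,
    v: &[f64],
    v_rows: usize,
    v_total_cols: usize,
    n_heads: usize,
) -> Result<(), String> {
    // A head count of zero would leave the per-head width undefined.
    if n_heads == 0 {
        return Err("n_heads must be at least 1".to_string());
    }

    // Assumed layout: each matrix is a row-major flat buffer.
    let mats: [(&str, &[f64], usize, usize); 3] = [
        ("q", q, q_rows, q_total_cols),
        ("k", k, k_rows, k_total_cols),
        ("v", v, v_rows, v_total_cols),
    ];
    for (name, data, rows, cols) in mats {
        // Guard against rows * cols overflowing before comparing lengths.
        let expected = rows
            .checked_mul(cols)
            .ok_or_else(|| format!("{}: {} x {} overflows usize", name, rows, cols))?;
        if data.len() != expected {
            return Err(format!(
                "{}: expected {} values ({} x {}), got {}",
                name, expected, rows, cols, data.len()
            ));
        }
        // Assumed: columns split evenly, so every head has the same width.
        if cols % n_heads != 0 {
            return Err(format!(
                "{}: {} columns not divisible by {} heads",
                name, cols, n_heads
            ));
        }
    }

    // Per-head Q * K^T needs queries and keys of the same width.
    if q_total_cols != k_total_cols {
        return Err(format!(
            "q has {} columns but k has {}",
            q_total_cols, k_total_cols
        ));
    }
    // Each key row is assumed to pair with exactly one value row.
    if k_rows != v_rows {
        return Err(format!("k has {} rows but v has {}", k_rows, v_rows));
    }
    Ok(())
}
```

For example, a 2 x 4 `q`, a 3 x 4 `k` and a 3 x 6 `v` with `n_heads = 2` would pass this sketch. Changing `n_heads` to 3 would fail, because 4 columns cannot be split evenly across 3 heads. The sketch deliberately does not require `v_total_cols` to equal `q_total_cols`, since attention only needs the value width to be divisible by the head count.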