From 73bc1987c8f889c347e033dbe9978f2a3d658aed Mon Sep 17 00:00:00 2001 From: 7shi <7shi@live.jp> Date: Wed, 22 Apr 2026 21:38:22 +0900 Subject: [PATCH 1/2] [fix] enable -DBITNET_X86_TL2=ON when -q tl2 is specified --- setup_env.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/setup_env.py b/setup_env.py index 3bf5fb8f7..628d35db6 100644 --- a/setup_env.py +++ b/setup_env.py @@ -239,6 +239,8 @@ def signal_handler(sig, frame): if __name__ == "__main__": signal.signal(signal.SIGINT, signal_handler) args = parse_args() + if args.quant_type == "tl2": + COMPILER_EXTRA_ARGS["x86_64"] = ["-DBITNET_X86_TL2=ON"] Path(args.log_dir).mkdir(parents=True, exist_ok=True) logging.basicConfig(level=logging.INFO) main() From f0c291873737e5a5e70e5160474e045c017ca1ac Mon Sep 17 00:00:00 2001 From: 7shi <7shi@live.jp> Date: Thu, 23 Apr 2026 02:35:49 +0900 Subject: [PATCH 2/2] [feat] add TL2 support for Falcon3 family models --- setup_env.py | 2 ++ utils/codegen_tl2.py | 30 +++++++++++++++++++++++++++++- 2 files changed, 31 insertions(+), 1 deletion(-) diff --git a/setup_env.py b/setup_env.py index 628d35db6..1e2957dd7 100644 --- a/setup_env.py +++ b/setup_env.py @@ -190,6 +190,8 @@ def gen_code(): shutil.copyfile(os.path.join(pretuned_kernels, "bitnet-lut-kernels-tl2.h"), "include/bitnet-lut-kernels.h") if get_model_name() == "bitnet_b1_58-large": run_command([sys.executable, "utils/codegen_tl2.py", "--model", "bitnet_b1_58-large", "--BM", "256,128,256", "--BK", "96,192,96", "--bm", "32,32,32"], log_step="codegen") + elif get_model_name().startswith("Falcon3"): + run_command([sys.executable, "utils/codegen_tl2.py", "--model", get_model_name(), "--BM", "256,128,256,128", "--BK", "96,96,96,96", "--bm", "32,32,32,32"], log_step="codegen") elif get_model_name() in llama3_f3_models: run_command([sys.executable, "utils/codegen_tl2.py", "--model", "Llama3-8B-1.58-100B-tokens", "--BM", "256,128,256,128", "--BK", "96,96,96,96", "--bm", "32,32,32,32"], log_step="codegen") elif get_model_name() == "bitnet_b1_58-3B": diff --git a/utils/codegen_tl2.py b/utils/codegen_tl2.py index 4d9408123..de35b1efe 100644 --- a/utils/codegen_tl2.py +++ b/utils/codegen_tl2.py @@ -690,7 +690,35 @@ def get_three_k_two_k(K, bk): "Llama3-8B-1.58-100B-tokens" : [[14336, 4096], [4096, 14336], [1024, 4096], - [4096, 4096]] + [4096, 4096]], + "Falcon3-10B-Instruct-1.58bit" : [[23040, 3072], + [3072, 23040], + [1024, 3072], + [3072, 3072]], + "Falcon3-10B-1.58bit" : [[23040, 3072], + [3072, 23040], + [1024, 3072], + [3072, 3072]], + "Falcon3-7B-Instruct-1.58bit" : [[23040, 3072], + [3072, 23040], + [1024, 3072], + [3072, 3072]], + "Falcon3-7B-1.58bit" : [[23040, 3072], + [3072, 23040], + [1024, 3072], + [3072, 3072]], + "Falcon3-3B-Instruct-1.58bit" : [[9216, 3072], + [3072, 9216], + [1024, 3072], + [3072, 3072]], + "Falcon3-3B-1.58bit" : [[9216, 3072], + [3072, 9216], + [1024, 3072], + [3072, 3072]], + "Falcon3-1B-Instruct-1.58bit" : [[8192, 2048], + [2048, 8192], + [1024, 2048], + [2048, 2048]], } parser = argparse.ArgumentParser(description='gen impl')