diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc index 7d771f1504640a5756486abcdf1bcf21e45e92f0..cebc7817a292011587f4941dfff502f5e5c98cbd 100644 Binary files a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc and b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc index 6b4fa56b8116b17718afc9b62e3ee61a10fc8f85..4f17eb8838cb893f61b2f42a345df9f69eb7e8bf 100644 Binary files a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc and b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc index 6c28ccabf5d15b099f889bcc9c674704f2d2a35e..fdd84db8007b53b4dd4368dc01ad9296d4b95830 100644 Binary files a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc and b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc index c9a18fa8106e824330c12089d84cc6c7758c3fd7..67e3eba29657209ffccf401d6e8c340a28057265 100644 Binary files a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc and b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc index 405282a0b844ab76a98031c22d5f064bf48592c3..f78fec2e6a76348b50183b154618fd095d197d13 100644 Binary files a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc and b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..0a843146f753e99ec745a7e2deb9c3db543a3482 --- /dev/null +++ b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9c7b5ae8af9477be3049ba1ae6af3f9e2d8bf82979fa9e9632c485a8d49f532a +size 64503960 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so deleted file mode 100755 index d6fdbaf71bdf580e7d5816f235b28ed46359d79c..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:9875e2f9570a0bca9725184062f07e0f904e87a514d25175cda87e0c95a8666a -size 64503976 diff --git a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_ops.py b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_ops.py index 92fecce393e361dda107f47aa06cd5df2924281d..c1d217d97eaddf8812c504cd7ca9656b8b72fba4 100644 --- a/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_ops.py +++ b/build/torch27-cxx11-cu118-x86_64-linux/causal_conv1d/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _causal_conv1d_db03e28_dirty -ops = torch.ops._causal_conv1d_db03e28_dirty +from . import _causal_conv1d_90f5a60 +ops = torch.ops._causal_conv1d_90f5a60 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_causal_conv1d_db03e28_dirty::{op_name}" \ No newline at end of file + return f"_causal_conv1d_90f5a60::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc index 1551eb5b53dbc3c61977f52179b70a089fd1df28..786bd24d264db90632165d514db4f5521d8133e0 100644 Binary files a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc and b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc index a6239878edc97c5a6d0759d26dcb32e98688bd1b..c9d95dcbdfd288de6d7077fd18259eb3bddb4958 100644 Binary files a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc and b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc index bb269fc8c92ec2ee2ee74e399e351e3a2e068bf1..297eb3e2197477c1df29aaeaff32a0c4b4ba8611 100644 Binary files a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc and b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc index da4fddd46d1b1dc938f04afa636a2d90c3fc3391..40dd2dec4c9da46e89fe46f6974c794ad6432a52 100644 Binary files a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc and b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc index 763ed3ed177891c42196ff33a9a3804929f1d2e7..c6d47e6accd7dc9a8d27df0e15f3363d6579b752 100644 Binary files a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc and b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..770b6a40b54b2b7dc3ff89d20de6dae3cfd5a06a --- /dev/null +++ b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:299bc47bf7fdea21eb71f9b0d0cd329a32e792106029d4fd5d6c637c76b9c6f7 +size 64213568 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so deleted file mode 100755 index 74940a8b1b7e690a9ec507e8e3c3095bbf180a27..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:f7c97c4b57e74c91956eca7b812f88fa647a2c3676cc52b2bf585f79cc73d9ea -size 64213584 diff --git a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py index 92fecce393e361dda107f47aa06cd5df2924281d..c1d217d97eaddf8812c504cd7ca9656b8b72fba4 100644 --- a/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py +++ b/build/torch27-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _causal_conv1d_db03e28_dirty -ops = torch.ops._causal_conv1d_db03e28_dirty +from . import _causal_conv1d_90f5a60 +ops = torch.ops._causal_conv1d_90f5a60 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_causal_conv1d_db03e28_dirty::{op_name}" \ No newline at end of file + return f"_causal_conv1d_90f5a60::{op_name}" \ No newline at end of file diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc index e385c426b54adbe60e6f705951227631db9cfd18..18d3139112b33e61419b6a2728a795c6a358861e 100644 Binary files a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc and b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc index 9d38e35200729fde9a6a49b934268706adf8f489..2857339f18ec0febb5d24266e3d91fbbe2fa820c 100644 Binary files a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc and b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc index 1083c3d0e18bb70e8900b92544e5faa660765504..bef01edf5863f420831c4fd7b62444df502bf29e 100644 Binary files a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc and b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc index e98a981b1953e8a0c05eb5e9ae9a371268e33adf..eb24c381faa7b75b6519cd48d85261bec5d03f1d 100644 Binary files a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc and b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc index a6a734a841d689e2f2d71ae59cee104c82aa09f5..7283c7bf629a1f669d5802a0aecc629cfcda5eb0 100644 Binary files a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc and b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..bc680b2d853196f8878bb9cb5f6d73bce3cecb80 --- /dev/null +++ b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:859247ab0b3e7852c4e1a6ac76f3c62b3aebea729241058075eb2e6f29139a50 +size 90656256 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so deleted file mode 100755 index 3f3829da2340c302763d70d8470582cf239f6467..0000000000000000000000000000000000000000 --- a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:5d631b6843ce08b07172d0d91c3eb55c1d4d5b45839337e1eeeeb430054f97ce -size 102460944 diff --git a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py index 92fecce393e361dda107f47aa06cd5df2924281d..c1d217d97eaddf8812c504cd7ca9656b8b72fba4 100644 --- a/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py +++ b/build/torch27-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _causal_conv1d_db03e28_dirty -ops = torch.ops._causal_conv1d_db03e28_dirty +from . import _causal_conv1d_90f5a60 +ops = torch.ops._causal_conv1d_90f5a60 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_causal_conv1d_db03e28_dirty::{op_name}" \ No newline at end of file + return f"_causal_conv1d_90f5a60::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc index 19b05560416a36434d6df3ca202a01344f56e5b8..5bd9fedd58db003f26a87e469371af220238bc87 100644 Binary files a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc and b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc index caae805f8c9f085caad051c4bd6c110f4f8a9c02..e16ce1c2addaebc17ff173ee0a8259c78d99ce87 100644 Binary files a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc and b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc index 826f93b5fd568a8a03bb8cbed2994c7a9e7f15ca..de29433be5c61a237f9a695103e2b93446ebf2da 100644 Binary files a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc and b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc index fffcd669af4726e9d0adf64fa8d1fe6159f2fe13..99d5a36d1bd8f837d74297fba4666c6b309826f4 100644 Binary files a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc and b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc index cbf994dd52d0457bdb89dfd81378f9c6a9ed0188..7ca3550517acb79d98fe21204e7f30a305fac041 100644 Binary files a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc and b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..2d221d1e9e61578aa03eeddcf419c9e3599f852e --- /dev/null +++ b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e0f4f7ddbf2822ad6dc2a89b513a4eb74700cff343ee3728952a8e35616978cf +size 64213792 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so deleted file mode 100755 index 87bee4c541294833d811aa201e3b5d10aaa2a479..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:dd6b0804c4998af6b6dd9f64dfdabe98ac3db66fe78f6e5ea1a4db32d667a431 -size 64213808 diff --git a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py index 92fecce393e361dda107f47aa06cd5df2924281d..c1d217d97eaddf8812c504cd7ca9656b8b72fba4 100644 --- a/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py +++ b/build/torch28-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _causal_conv1d_db03e28_dirty -ops = torch.ops._causal_conv1d_db03e28_dirty +from . import _causal_conv1d_90f5a60 +ops = torch.ops._causal_conv1d_90f5a60 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_causal_conv1d_db03e28_dirty::{op_name}" \ No newline at end of file + return f"_causal_conv1d_90f5a60::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc index 6b5eb0a6b50d00f52d6da06563d1d28fc1e564b8..c5d1f96e4cc335fc206a51c3a2576f6ee6d7317a 100644 Binary files a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc and b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc index f5cf938d3e1eeebbb152880a32435e91b9ca5b5f..38bc20e5f91e6eb89c5f3153fc38fa8123929cd0 100644 Binary files a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc and b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc index 2610dba23471323009b19ced6f9e2e669ee5b7e9..f20f606aee4628b4f8f85d4cf90b889fa729a7c2 100644 Binary files a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc and b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc index 39d2a0b66b76159be8e52cfcf6c75af7826f82fa..d510f8c40a7edd1a7aea534463c4e25509ed5b9b 100644 Binary files a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc and b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc index c5eaa70fc7ef0ce097eab4af9b3aac82a72eba7a..f5a0cb66c9d790d97b67a50309819c4264f221a0 100644 Binary files a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc and b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..25359c38281dc9dcf6d69db4ade464f8b23896e9 --- /dev/null +++ b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:47aabbc543b8f9899f8447973962253cf0201f88c7dff965b2b387bf369fb204 +size 90660576 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so deleted file mode 100755 index ba9129f6115ce77aedda557ae5da82241b51ce9b..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:32eb5c27a38a2f9cad4076e3d3e47b1cfe1473604b7867eb87259872c49fe64a -size 102465272 diff --git a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py index 92fecce393e361dda107f47aa06cd5df2924281d..c1d217d97eaddf8812c504cd7ca9656b8b72fba4 100644 --- a/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py +++ b/build/torch28-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _causal_conv1d_db03e28_dirty -ops = torch.ops._causal_conv1d_db03e28_dirty +from . import _causal_conv1d_90f5a60 +ops = torch.ops._causal_conv1d_90f5a60 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_causal_conv1d_db03e28_dirty::{op_name}" \ No newline at end of file + return f"_causal_conv1d_90f5a60::{op_name}" \ No newline at end of file diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc index 79b116d0a0cb0e8aa4a404292f8bc2c14dcbd25e..10aec904bf54ef45b6699d8e9f9fb97e90396282 100644 Binary files a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc and b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc index aa7a049efac4ba97aac5288678ce588364d5868b..738df5f4deb4701c9799066d7995f931b020feff 100644 Binary files a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc and b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc index 2a730525ba8d9fe9b2f313196e086388d4c3f1d6..8fe96fc4607af616511dcb714b268e2bd6baa877 100644 Binary files a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc and b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc index 4c2596be71aa4fd852dc7077ec79a843998e862f..9aee2fe5ecd3870a4f2d21056ea8de606eb5322a 100644 Binary files a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc and b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc index 045a8072a22dc785881ebdcc4173848ef8259a3d..a03628fc814c40091c6ec925e8a5c1f7279d5e3d 100644 Binary files a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc and b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..a139496a5202bd49b0e4a150a0b6367e629b5368 --- /dev/null +++ b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bd59adbb4d01c1d3b986f6ab63e3873a57cc8f4cae885c7ad83fe7c9df16b395 +size 97498136 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so deleted file mode 100755 index 481a6e6b7c040bdd092adfc01329d856e55af62f..0000000000000000000000000000000000000000 --- a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/_causal_conv1d_db03e28_dirty.abi3.so +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:6d95e7ebb7ad5f7881f5876269e47f5b28dd22b4de7c2d598476e037f67937c6 -size 110109736 diff --git a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/_ops.py b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/_ops.py index 92fecce393e361dda107f47aa06cd5df2924281d..c1d217d97eaddf8812c504cd7ca9656b8b72fba4 100644 --- a/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/_ops.py +++ b/build/torch28-cxx11-cu129-x86_64-linux/causal_conv1d/_ops.py @@ -1,9 +1,9 @@ import torch -from . import _causal_conv1d_db03e28_dirty -ops = torch.ops._causal_conv1d_db03e28_dirty +from . import _causal_conv1d_90f5a60 +ops = torch.ops._causal_conv1d_90f5a60 def add_op_namespace_prefix(op_name: str): """ Prefix op by namespace. """ - return f"_causal_conv1d_db03e28_dirty::{op_name}" \ No newline at end of file + return f"_causal_conv1d_90f5a60::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebc4c00dbc57ecd70aace2d48ad170591723c5fd Binary files /dev/null and b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1094241fa76fea27d8f48a8f478357ef4e76a63d Binary files /dev/null and b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e145721f7fa7efcd1b17be15328f4948d8e70294 Binary files /dev/null and b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b33b867e68979585cf08ba23067d07a77ace8e8c Binary files /dev/null and b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b549cc522fb634296f8309c87026696ec0c43add Binary files /dev/null and b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..95c7e1c5eb4abbfd1c02c342421e2f1630ee52b0 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:68f4c1657b7f67aac24a796a9d2379ccc67751b93e28950ad9e0849256cedcc4 +size 64217976 diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c1d217d97eaddf8812c504cd7ca9656b8b72fba4 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_90f5a60 +ops = torch.ops._causal_conv1d_90f5a60 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_90f5a60::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_interface.py b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/cpp_functions.py b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch29-cxx11-cu126-x86_64-linux/causal_conv1d/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94566f49416a5d8427c82e759945901e342d642d Binary files /dev/null and b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfea2e149ced6b52a04e5e09de4dd3eca9736c4d Binary files /dev/null and b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2de185bb18791e120a4294e7db8b30bfebfcb06 Binary files /dev/null and b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b76dc8a3cb047163285f158388b3aadcd7fbb02 Binary files /dev/null and b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dc7e8d3e1329f4ab572a461f292c4cdb7c1f8f2e Binary files /dev/null and b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..7c67b3d0d50fc3d020ffa125e1e138f5145ad885 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e793340f434e5927b198c3079bb7affd92c0061f521cac38e0b071cea9df1b88 +size 90660664 diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c1d217d97eaddf8812c504cd7ca9656b8b72fba4 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_90f5a60 +ops = torch.ops._causal_conv1d_90f5a60 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_90f5a60::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_interface.py b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/cpp_functions.py b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch29-cxx11-cu128-x86_64-linux/causal_conv1d/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e00d508d004ebe3eafe214d2b1b2ec2a44090d5c --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__init__.py @@ -0,0 +1,4 @@ +from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update +from .causal_conv1d_varlen import causal_conv1d_varlen_states + +__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"] diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94823edf987983958e30a5ce5dab0c9d2e3ca3e2 Binary files /dev/null and b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__pycache__/__init__.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..76d10867dc70fe0b9cbf357663a90a19901119cf Binary files /dev/null and b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__pycache__/_ops.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ee5ae99bfc4481e93435fd239e6ea9d1fcb9f58 Binary files /dev/null and b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_interface.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a5ca0138b562ca1a9e31af0f1e4674ab8deb6a1 Binary files /dev/null and b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__pycache__/causal_conv1d_varlen.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc new file mode 100644 index 0000000000000000000000000000000000000000..548d455c04a3776fd22333448c658ec952e67a21 Binary files /dev/null and b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/__pycache__/cpp_functions.cpython-313.pyc differ diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so new file mode 100755 index 0000000000000000000000000000000000000000..f1b8b0235f85cbe20b7862a4279454e018cb6cd2 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/_causal_conv1d_90f5a60.abi3.so @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ef1065ff555921a0982e3c66ba18fcf771da03395373ec41b4802a9a769a761f +size 58210488 diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/_ops.py b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/_ops.py new file mode 100644 index 0000000000000000000000000000000000000000..c1d217d97eaddf8812c504cd7ca9656b8b72fba4 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/_ops.py @@ -0,0 +1,9 @@ +import torch +from . import _causal_conv1d_90f5a60 +ops = torch.ops._causal_conv1d_90f5a60 + +def add_op_namespace_prefix(op_name: str): + """ + Prefix op by namespace. + """ + return f"_causal_conv1d_90f5a60::{op_name}" \ No newline at end of file diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/causal_conv1d_interface.py b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/causal_conv1d_interface.py new file mode 100644 index 0000000000000000000000000000000000000000..d46d56c415e91e16a475a7261e01658b2259d377 --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/causal_conv1d_interface.py @@ -0,0 +1,242 @@ +# Copyright (c) 2024, Tri Dao. + +import torch +import torch.nn.functional as F + +from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function + + +class CausalConv1dFn(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, + ): + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(2) != 1 and x.stride(1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None + if seq_idx is not None: + assert ( + initial_states is None + ), "initial_states must be None if seq_idx is not None" + assert ( + not return_final_states + ), "If seq_idx is not None, we don't return final_states_out" + seq_idx = seq_idx.contiguous() if seq_idx is not None else None + if initial_states is not None and ( + initial_states.stride(2) != 1 and initial_states.stride(1) != 1 + ): + initial_states = initial_states.contiguous() + if return_final_states: + assert ( + x.stride(1) == 1 + ), "Only channel-last layout support returning final_states_out" + if final_states_out is not None: + assert ( + final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1 + ) + else: + batch, dim, seqlen = x.shape + width = weight.shape[1] + final_states_out = torch.empty( + batch, width - 1, dim, device=x.device, dtype=x.dtype + ).transpose(1, 2) + else: + final_states_out = None + ctx.activation = activation in ["silu", "swish"] + out = causal_conv1d_fwd_function( + x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation + ) + ctx.save_for_backward(x, weight, bias, seq_idx, initial_states) + ctx.return_final_states = return_final_states + ctx.return_dinitial_states = ( + initial_states is not None and initial_states.requires_grad + ) + return out if not return_final_states else (out, final_states_out) + + @staticmethod + def backward(ctx, dout, *args): + x, weight, bias, seq_idx, initial_states = ctx.saved_tensors + dfinal_states = args[0] if ctx.return_final_states else None + if dout.stride(2) != 1 and dout.stride(1) != 1: + dout = dout.contiguous() + # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the + # backward of conv1d with the backward of chunk). + # Here we just pass in None and dx will be allocated in the C++ code. + dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function( + x, + weight, + bias, + dout, + seq_idx, + initial_states, + dfinal_states, + None, + ctx.return_dinitial_states, + ctx.activation, + ) + return ( + dx, + dweight, + dbias if bias is not None else None, + None, + dinitial_states if initial_states is not None else None, + None, + None, + None, + ) + + +def causal_conv1d_fn( + x, + weight, + bias=None, + seq_idx=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + seq_idx: (batch, seqlen) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1), to be written to + activation: either None or "silu" or "swish" + + out: (batch, dim, seqlen) + """ + return CausalConv1dFn.apply( + x, + weight, + bias, + seq_idx, + initial_states, + return_final_states, + final_states_out, + activation, + ) + + +def causal_conv1d_ref( + x, + weight, + bias=None, + initial_states=None, + return_final_states=False, + final_states_out=None, + activation=None, +): + """ + x: (batch, dim, seqlen) + weight: (dim, width) + bias: (dim,) + initial_states: (batch, dim, width - 1) + final_states_out: (batch, dim, width - 1) + + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + x = x.to(weight.dtype) + seqlen = x.shape[-1] + dim, width = weight.shape + if initial_states is None: + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim) + else: + x = torch.cat([initial_states, x], dim=-1) + out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim) + out = out[..., :seqlen] + if return_final_states: + final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to( + dtype_in + ) # (batch, dim, width - 1) + if final_states_out is not None: + final_states_out.copy_(final_states) + else: + final_states_out = final_states + out = (out if activation is None else F.silu(out)).to(dtype=dtype_in) + return out if not return_final_states else (out, final_states_out) + + +def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + activation = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + out = causal_conv1d_update_function( + x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices + ) + if unsqueeze: + out = out.squeeze(-1) + return out + + +def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None): + """ + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. + + out: (batch, dim) or (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + dtype_in = x.dtype + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape + width = weight.shape[1] + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) + assert weight.shape == (dim, width) + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype) + copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) + return (out if activation is None else F.silu(out)).to(dtype=dtype_in) diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py new file mode 100644 index 0000000000000000000000000000000000000000..8005af233d5c21b0a58917a8a18045636c2351cb --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/causal_conv1d_varlen.py @@ -0,0 +1,86 @@ +import torch +from torch import Tensor + +import triton +import triton.language as tl + + +@triton.jit +def _causal_conv1d_varlen_states( + X, + CU_SEQLENS, + STATES, + state_len, + dim, + stride_x_seqlen, stride_x_dim, + stride_states_batch, stride_states_seqlen, stride_states_dim, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr +): + batch_idx = tl.program_id(2) + STATES += batch_idx * stride_states_batch + end_idx = tl.load(CU_SEQLENS + batch_idx + 1) + start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len) + rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N) + x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim, + mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim), + other=0) + rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M) + tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim, + x, + mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim)) + + +def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + BLOCK_M = min(triton.next_power_of_2(state_len), 16) + BLOCK_N = min(triton.next_power_of_2(dim), 256) + grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch) + with torch.cuda.device(x.device.index): + _causal_conv1d_varlen_states[grid]( + x, + cu_seqlens, + states, + state_len, + dim, + x.stride(0), x.stride(1), + states.stride(0), states.stride(2), states.stride(1), + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N + ) + return states + + +def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor: + """ + Forward pass only, does not support backward pass. + Parameters: + x: (total_tokens, dim) + cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0. + state_len: int. For each cu_seqlens, how many elements from x should be copied to the state. + If some of those elements belong to a different sequence, the value of the states will be zero. + Return: + states: (batch, dim, state_len) + """ + _, dim = x.shape + batch = cu_seqlens.shape[0] - 1 + cu_seqlens = cu_seqlens.contiguous() + states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2) + for i in range(batch): + end_idx = cu_seqlens[i + 1] + start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len) + states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T + return states diff --git a/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/cpp_functions.py b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/cpp_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..1ddb9f83ceb4e9f72754fe39340738d47b6aea1b --- /dev/null +++ b/build/torch29-cxx11-cu130-x86_64-linux/causal_conv1d/cpp_functions.py @@ -0,0 +1,96 @@ +# Copyright (c) 2024, Tri Dao. + +import torch + +from ._ops import ops + +def causal_conv1d_fwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + final_states_out: torch.Tensor | None, + silu_activation: bool, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_fwd( + x=x, + weight=weight, + bias=bias, + seq_idx=seq_idx, + initial_states=initial_states, + out=out, + final_states_out=final_states_out, + silu_activation=silu_activation, + ) + return out + + +def causal_conv1d_bwd_function( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + dout: torch.Tensor, + seq_idx: torch.Tensor | None, + initial_states: torch.Tensor | None, + dfinal_states: torch.Tensor | None, + dx: torch.Tensor | None, + return_dinitial_states: torch.Tensor, + silu_activation: bool, +) -> tuple[torch.Tensor | None]: + batch_size, dim = x.size()[:2] + width = weight.size(-1) + + if dx is None: + dx = torch.empty_like(x) + dweight = torch.zeros_like(weight, dtype=torch.float32) + dbias = None + if bias is not None: + dbias = torch.zeros_like(bias, dtype=torch.float32) + dinitial_states = None + if return_dinitial_states: + dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2) + + ops.causal_conv1d_bwd( + x=x, + weight=weight, + bias=bias, + dout=dout, + seq_idx=seq_idx, + initial_states=initial_states, + dfinal_states=dfinal_states, + dx=dx, + dweight=dweight, + dbias=dbias, + dinitial_states=dinitial_states, + silu_activation=silu_activation, + ) + + dweight = dweight.type_as(weight) + if dbias is not None: + dbias = dbias.type_as(bias) + return dx, dweight, dbias, dinitial_states + + +def causal_conv1d_update_function( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor | None, + silu_activation: bool, + cache_seqlens: torch.Tensor | None, + conv_state_indices: torch.Tensor | None, +) -> torch.Tensor: + out = torch.empty_like(x) + ops.causal_conv1d_update( + x=x, + conv_state=conv_state, + weight=weight, + bias=bias, + out=out, + silu_activation=silu_activation, + cache_seqlens=cache_seqlens, + conv_state_indices=conv_state_indices, + ) + return out