Skip to content

Commit

Permalink
Merge pull request #6 from sappho192/dev_encode
Browse files Browse the repository at this point in the history
Fixes #3; implement encode() function
  • Loading branch information
sappho192 authored Jun 18, 2024
2 parents f45946a + 2eb6968 commit f553a0d
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 10 deletions.
2 changes: 1 addition & 1 deletion NATIVE_LIB_VERSION.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.6.1
1.0.0
27 changes: 25 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@

* [X] Download tokenizer files from Hugging Face Hub
* [X] Load tokenizer file(`.json`) from local
* [X] Decode embeddings to string
* [X] Encode string to tokens
* [X] Decode tokens to string

# How to use

Expand All @@ -50,11 +51,33 @@ Console.WriteLine($"Downloaded {fileFullPath}");

// Create a tokenizer instance
var tokenizer = new Tokenizer(vocabPath: fileFullPath);
var tokens = new uint[] { 9330, 387, 12857, 9376, 18649, 9098, 7656, 6969, 8084, 1 };
var text = "음, 이제 식사도 해볼까요";
Console.WriteLine($"Input text: {text}");
var tokens = tokenizer.Encode(text);
Console.WriteLine($"Encoded: {string.Join(", ", tokens)}");
var decoded = tokenizer.Decode(tokens);
Console.WriteLine($"Decoded: {decoded}");

Console.WriteLine($"Version of Tokenizers.DotNet.runtime.win: {tokenizer.GetVersion()}");

Console.WriteLine("--------------------------------------------------");
// Use another tokenizer
//// Download openai-community/gpt2 from the hub
hubName = "openai-community/gpt2";
filePath = "tokenizer.json";
fileFullPath = await HuggingFace.GetFileFromHub(hubName, filePath, "deps");

// Create a tokenizer instance
var tokenizer2 = new Tokenizer(vocabPath: fileFullPath);
var text2 = "i was nervous before the exam, and i had a fever.";
Console.WriteLine($"Input text: {text2}");
var tokens2 = tokenizer2.Encode(text2);
Console.WriteLine($"Encoded: {string.Join(", ", tokens2)}");
var decoded2 = tokenizer2.Decode(tokens2);
Console.WriteLine($"Decoded: {decoded2}");

Console.WriteLine($"Version of Tokenizers.DotNet.runtime.win: {tokenizer2.GetVersion()}");
Console.ReadKey();
```

# How to build
Expand Down
2 changes: 1 addition & 1 deletion dotnet/ConsoleExample/ConsoleExample.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
</PropertyGroup>

<ItemGroup>
<PackageReference Include="Tokenizers.DotNet.runtime.win" Version="0.6.1" />
<PackageReference Include="Tokenizers.DotNet.runtime.win" Version="0.7.0" />
</ItemGroup>

<ItemGroup>
Expand Down
15 changes: 11 additions & 4 deletions dotnet/ConsoleExample/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,24 +8,31 @@

// Create a tokenizer instance
var tokenizer = new Tokenizer(vocabPath: fileFullPath);
var tokens = new uint[] { 9330, 387, 12857, 9376, 18649, 9098, 7656, 6969, 8084, 1 };
var text = "음, 이제 식사도 해볼까요";
Console.WriteLine($"Input text: {text}");
var tokens = tokenizer.Encode(text);
Console.WriteLine($"Encoded: {string.Join(", ", tokens)}");
var decoded = tokenizer.Decode(tokens);
Console.WriteLine($"Decoded: {decoded}");

Console.WriteLine($"Version of Tokenizers.DotNet.runtime.win: {tokenizer.GetVersion()}");

// Download openai-community/gpt2 from the hub
Console.WriteLine("--------------------------------------------------");

//// Download openai-community/gpt2 from the hub
hubName = "openai-community/gpt2";
filePath = "tokenizer.json";
fileFullPath = await HuggingFace.GetFileFromHub(hubName, filePath, "deps");

// Create a tokenizer instance
var tokenizer2 = new Tokenizer(vocabPath: fileFullPath);
var tokens2 = new uint[] { 72, 373, 10927, 878, 262, 2814, 11, 290, 1312, 550, 257, 17372, 13 };
var text2 = "i was nervous before the exam, and i had a fever.";
Console.WriteLine($"Input text: {text2}");
var tokens2 = tokenizer2.Encode(text2);
Console.WriteLine($"Encoded: {string.Join(", ", tokens2)}");
var decoded2 = tokenizer2.Decode(tokens2);
Console.WriteLine($"Decoded: {decoded2}");

Console.WriteLine($"Version of Tokenizers.DotNet.runtime.win: {tokenizer2.GetVersion()}");
Console.ReadKey();


3 changes: 3 additions & 0 deletions dotnet/Tokenizers.DotNet/NativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ internal static unsafe partial class NativeMethods
[DllImport(__DllName, EntryPoint = "tokenizer_initialize", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern ByteBuffer* tokenizer_initialize(ushort* utf16_path, int utf16_path_len);

// Encodes a UTF-16 string into token IDs using the native tokenizer bound to
// the given session ID; lengths are in UTF-16 code units. Returns a pointer to
// a natively allocated ByteBuffer holding the u32 token IDs.
[DllImport(__DllName, EntryPoint = "tokenizer_encode", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern ByteBuffer* tokenizer_encode(ushort* _session_id, int _session_id_len, ushort* _text, int _text_len);

// Decodes token IDs back into text using the native tokenizer bound to the
// given session ID. Returns a natively allocated ByteBuffer; per the native
// side this holds a UTF-8 string, and the caller must free the memory.
[DllImport(__DllName, EntryPoint = "tokenizer_decode", CallingConvention = CallingConvention.Cdecl, ExactSpelling = true)]
public static extern ByteBuffer* tokenizer_decode(ushort* _session_id, int _session_id_len, uint* _token_ids, int _token_ids_len);

Expand Down
15 changes: 15 additions & 0 deletions dotnet/Tokenizers.DotNet/Tokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,21 @@ public Tokenizer(string vocabPath)
}
}

// Encodes the given text into token IDs by calling the native
// tokenizer_encode() function, identifying this tokenizer instance on the
// native side via the sessionId field.
//
// NOTE(review): the ByteBuffer returned by the native call is copied via
// AsSpan/ToArray but never explicitly freed here — confirm whether the
// native library expects the managed side to release it (the native decode
// path documents "Caller must free the memory").
public uint[] Encode(string text)
{
unsafe
{
// Pin the managed strings so their UTF-16 contents can be passed to
// native code as raw pointers.
fixed (char* p = sessionId)
{
fixed (char* pt = text)
{
// Lengths are in UTF-16 code units (C# string.Length).
var tokensRaw = NativeMethods.tokenizer_encode((ushort*)p, sessionId.Length, (ushort*)pt, text.Length);
// Reinterpret the native buffer as u32 token IDs and copy them
// into a managed array before returning.
var tokens = tokensRaw->AsSpan<uint>();
return tokens.ToArray();
}
}
}
}

public string Decode(uint[] tokens)
{
Expand Down
2 changes: 1 addition & 1 deletion nuget/win-x64/Tokenizers.DotNet.runtime.win.nuspec
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
<package >
<metadata>
<id>Tokenizers.DotNet.runtime.win</id>
<version>0.6.1</version>
<version>1.0.0</version>
<title>Tokenizers.DotNet.runtime.win</title>
<authors>sappho192</authors>
<owners>sappho192</owners>
Expand Down
2 changes: 1 addition & 1 deletion rust/Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "hf_tokenizers"
version = "0.6.1"
version = "1.0.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
Expand Down
35 changes: 35 additions & 0 deletions rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,41 @@ pub unsafe extern "C" fn tokenizer_initialize(
Box::into_raw(Box::new(ByteBuffer::from_vec(session_id.into_bytes())))
}

/// Encodes a UTF-16 string into token IDs using the tokenizer registered
/// under `_session_id`, returning a heap-allocated `ByteBuffer` of `u32`
/// token IDs whose ownership passes to the caller.
///
/// # Safety
/// `_session_id` and `_text` must point to valid UTF-16 data of the given
/// lengths (counted in u16 code units).
///
/// NOTE(review): this function panics on invalid UTF-16, on an unknown
/// session ID, and on tokenizer errors. A panic unwinding across an
/// `extern "C"` boundary is not sound (an abort on recent Rust, UB on
/// older toolchains) — consider returning a null pointer or an error
/// payload instead.
#[no_mangle]
pub unsafe extern "C" fn tokenizer_encode(
    _session_id: *const u16,
    _session_id_len: i32,
    _text: *const u16,
    _text_len: i32,
) -> *mut ByteBuffer {
    // Reconstruct owned Rust strings from the raw UTF-16 buffers.
    let slice_session_id = std::slice::from_raw_parts(_session_id, _session_id_len as usize);
    let session_id = String::from_utf16(slice_session_id).unwrap();
    let slice_text = std::slice::from_raw_parts(_text, _text_len as usize);
    let text = String::from_utf16(slice_text).unwrap();

    // Retrieve the tokenizer associated with the session ID
    let tokenizer = TOKENIZER_DB
        .get(&session_id)
        .cloned()
        .unwrap_or_else(|| panic!("Tokenizer for session ID '{}' not found.", session_id));

    // Encode the text; `encode` takes the string by value, so no clone is
    // needed (the original cloned `text` unnecessarily).
    let encoded_tokens = match tokenizer.encode(text, true) {
        Ok(encoded) => encoded,
        Err(err) => panic!("{}", err),
    };
    // `get_ids()` already yields `&[u32]`; copy it out directly instead of
    // mapping each element through an identity `as u32` cast.
    let token_ids = encoded_tokens.get_ids().to_vec();

    // Convert the token IDs to a ByteBuffer and hand ownership to the caller.
    Box::into_raw(Box::new(ByteBuffer::from_vec_struct(token_ids)))
}

// Returns u8string. Caller must free the memory
#[no_mangle]
pub unsafe extern "C" fn tokenizer_decode(
Expand Down

0 comments on commit f553a0d

Please sign in to comment.