Skip to content
This repository has been archived by the owner on Oct 5, 2023. It is now read-only.

Initial commit of C# port - does not build #184

Draft
wants to merge 1 commit into
base: develop
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 58 additions & 0 deletions ports/win32-csharp/src/App.config
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
<?xml version="1.0" encoding="utf-8"?>
<!--
  Application configuration for the win32-csharp port.
  Declares the supported runtime and a set of assembly binding redirects.
-->
<configuration>
<!-- Runtime selection: CLR v4.0, targeting .NET Framework 4.7.1. -->
<startup>
<supportedRuntime version="v4.0" sku=".NETFramework,Version=v4.7.1" />
</startup>
<runtime>
<!--
  Binding redirects. Each dependentAssembly entry maps every requested version
  in the oldVersion range of the named assembly to the single newVersion.
  NOTE(review): the newVersion values presumably match the package versions
  referenced by the project; verify against the project's package references
  whenever a package is upgraded.
-->
<assemblyBinding xmlns="urn:schemas-microsoft-com:asm.v1">
<dependentAssembly>
<assemblyIdentity name="System.Memory" publicKeyToken="cc7b13ffcd2ddd51" culture="neutral" />
<bindingRedirect oldVersion="0.0.0.0-4.0.1.1" newVersion="4.0.1.1" />
</dependentAssembly>
<dependentAssembly>
<assemblyIdentity name="System.Numerics.Vectors" publicKeyToken="b03f5f7f11d50a3a" culture="neutral" />
<bindingRedirect oldVersion="0.0.0.0-4.1.4.0" newVersion="4.1.4.0" />
</dependentAssembly>
<dependentAssembly>
<assemblyIdentity name="System.Runtime.CompilerServices.Unsafe" publicKeyToken="b03f5f7f11d50a3a" culture="neutral" />
<bindingRedirect oldVersion="0.0.0.0-4.0.6.0" newVersion="4.0.6.0" />
</dependentAssembly>
<dependentAssembly>
<assemblyIdentity name="System.Buffers" publicKeyToken="cc7b13ffcd2ddd51" culture="neutral" />
<bindingRedirect oldVersion="0.0.0.0-4.0.3.0" newVersion="4.0.3.0" />
</dependentAssembly>
<dependentAssembly>
<assemblyIdentity name="System.Collections.Immutable" publicKeyToken="b03f5f7f11d50a3a" culture="neutral" />
<bindingRedirect oldVersion="0.0.0.0-1.2.5.0" newVersion="1.2.5.0" />
</dependentAssembly>
<dependentAssembly>
<assemblyIdentity name="Newtonsoft.Json" publicKeyToken="30ad4fe6b2a6aeed" culture="neutral" />
<bindingRedirect oldVersion="0.0.0.0-12.0.0.0" newVersion="12.0.0.0" />
</dependentAssembly>
<dependentAssembly>
<assemblyIdentity name="System.Threading.Tasks.Dataflow" publicKeyToken="b03f5f7f11d50a3a" culture="neutral" />
<bindingRedirect oldVersion="0.0.0.0-4.6.5.0" newVersion="4.6.5.0" />
</dependentAssembly>
<dependentAssembly>
<assemblyIdentity name="TensorFlow.NET" publicKeyToken="cc7b13ffcd2ddd51" culture="neutral" />
<bindingRedirect oldVersion="0.0.0.0-0.13.0.0" newVersion="0.13.0.0" />
</dependentAssembly>
<dependentAssembly>
<assemblyIdentity name="System.Security.Principal.Windows" publicKeyToken="b03f5f7f11d50a3a" culture="neutral" />
<bindingRedirect oldVersion="0.0.0.0-4.1.3.0" newVersion="4.1.3.0" />
</dependentAssembly>
<dependentAssembly>
<assemblyIdentity name="System.IO.FileSystem.AccessControl" publicKeyToken="b03f5f7f11d50a3a" culture="neutral" />
<bindingRedirect oldVersion="0.0.0.0-4.0.5.0" newVersion="4.0.5.0" />
</dependentAssembly>
<dependentAssembly>
<assemblyIdentity name="System.Security.AccessControl" publicKeyToken="b03f5f7f11d50a3a" culture="neutral" />
<bindingRedirect oldVersion="0.0.0.0-4.1.3.0" newVersion="4.1.3.0" />
</dependentAssembly>
<dependentAssembly>
<assemblyIdentity name="Google.Protobuf" publicKeyToken="a7d26565bac4d604" culture="neutral" />
<bindingRedirect oldVersion="0.0.0.0-3.11.1.0" newVersion="3.11.1.0" />
</dependentAssembly>
</assemblyBinding>
</runtime>
</configuration>
94 changes: 94 additions & 0 deletions ports/win32-csharp/src/Encoder.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Newtonsoft.Json;
using Newtonsoft.Json.Serialization;
using Newtonsoft.Json.Utilities;
using System.IO;
using System.Text.RegularExpressions;

namespace AIDungeon.net
{
/// <summary>
/// Holds the lookup tables for a byte-pair-encoding vocabulary: the token/id maps,
/// the byte/character maps, the merge ranks and the token-splitting pattern.
/// Instances are normally created from a model directory via <see cref="GetEncoder"/>.
/// </summary>
public class Encoder
{
    /// <summary>Token text to token id, exactly as passed to the constructor.</summary>
    public IDictionary<string, int> EncodeDict { get; protected set; }

    /// <summary>Token id to token text; the inverse of <see cref="EncodeDict"/>.</summary>
    public IDictionary<int, string> DecodeDict { get; protected set; }

    /// <summary>Error-handling mode name supplied by the caller; only stored by this class.</summary>
    public string Errors { get; set; }

    /// <summary>Byte value (0-255) to the character that represents it.</summary>
    public IDictionary<int, char> ByteEncoder { get; set; }

    /// <summary>Character to byte value; the inverse of <see cref="ByteEncoder"/>.</summary>
    public IDictionary<char, int> ByteDecoder { get; set; }

    /// <summary>Merge pair to its zero-based position in the merge list.</summary>
    public IDictionary<Tuple<string, string>, int> BpeRanks { get; set; }

    /// <summary>String-to-string cache; created empty by the constructor.</summary>
    public IDictionary<string, string> Cache { get; set; }

    /// <summary>Regular expression used to split text into pre-tokens.</summary>
    public Regex Pattern { get; set; }

    /// <summary>
    /// Builds all lookup tables from a vocabulary and an ordered merge list.
    /// </summary>
    /// <param name="encoder">Token text to token id map. Values must be unique so the map can be inverted.</param>
    /// <param name="bpeMerges">Merge pairs in rank order. Pairs must be unique.</param>
    /// <param name="errors">Error-handling mode name; stored in <see cref="Errors"/>.</param>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="encoder"/> or <paramref name="bpeMerges"/> is null.
    /// </exception>
    /// <exception cref="ArgumentException">
    /// <paramref name="encoder"/> contains duplicate ids, or <paramref name="bpeMerges"/> contains duplicate pairs.
    /// </exception>
    public Encoder(Dictionary<string, int> encoder, IEnumerable<Tuple<string, string>> bpeMerges, string errors = "replace")
    {
        // Fail with a named argument instead of an anonymous NullReferenceException further down.
        if (encoder == null) throw new ArgumentNullException(nameof(encoder));
        if (bpeMerges == null) throw new ArgumentNullException(nameof(bpeMerges));

        EncodeDict = encoder;
        DecodeDict = encoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
        Errors = errors;
        ByteEncoder = BytesToUnicode();
        ByteDecoder = ByteEncoder.ToDictionary(kvp => kvp.Value, kvp => kvp.Key);
        BpeRanks = bpeMerges.Select(BpeConverter).ToDictionary(kvp => kvp.Key, kvp => kvp.Value);
        Cache = new Dictionary<string, string>();
        Pattern = new Regex(@"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+");
    }

    /// <summary>
    /// Builds the byte-to-character table. Bytes inside the three ranges below map to the
    /// character with the same code point; every other byte is assigned the next unused
    /// character starting at U+0100, in ascending byte order.
    /// </summary>
    /// <returns>A map with exactly one entry for each byte value 0-255.</returns>
    private static IDictionary<int, char> BytesToUnicode()
    {
        // Range bounds are written as escapes so the table does not depend on the
        // encoding the source file happens to be saved in.
        const int start1 = '!';
        const int end1 = '~';
        const int start2 = '\u00A1'; // inverted exclamation mark
        const int end2 = '\u00AC';   // not sign
        const int start3 = '\u00AE'; // registered sign
        const int end3 = '\u00FF';   // y with diaeresis

        var dict = new Dictionary<int, char>();
        for (var i = start1; i <= end1; i++)
        {
            dict[i] = Convert.ToChar(i);
        }
        for (var i = start2; i <= end2; i++)
        {
            dict[i] = Convert.ToChar(i);
        }
        for (var i = start3; i <= end3; i++)
        {
            dict[i] = Convert.ToChar(i);
        }

        // Remaining bytes get consecutive characters above the single-byte range.
        var n = 0;
        for (var b = 0; b <= 0xFF; b++)
        {
            if (dict.ContainsKey(b)) continue;
            dict[b] = Convert.ToChar(0x100 + n++);
        }

        return dict;
    }

    /// <summary>Pairs a merge tuple with its position so it can be used as a dictionary entry.</summary>
    private static KeyValuePair<Tuple<string, string>, int> BpeConverter(Tuple<string, string> tuple, int idx) =>
        new KeyValuePair<Tuple<string, string>, int>(tuple, idx);

    /// <summary>
    /// Loads <c>encoder.json</c> and <c>vocab.bpe</c> from <c>modelsDir/modelName</c> and builds an encoder.
    /// The first line of <c>vocab.bpe</c> is skipped, and so is any line that does not consist of
    /// exactly two space-separated fields.
    /// </summary>
    /// <param name="modelName">Name of the model sub-directory.</param>
    /// <param name="modelsDir">Directory that contains the model sub-directories.</param>
    /// <returns>An encoder populated from the two files.</returns>
    /// <exception cref="ArgumentNullException">
    /// <paramref name="modelName"/> or <paramref name="modelsDir"/> is null.
    /// </exception>
    /// <exception cref="InvalidDataException"><c>encoder.json</c> deserializes to null.</exception>
    public static Encoder GetEncoder(string modelName, string modelsDir)
    {
        if (modelName == null) throw new ArgumentNullException(nameof(modelName));
        if (modelsDir == null) throw new ArgumentNullException(nameof(modelsDir));

        var encoderPath = Path.Combine(modelsDir, modelName, "encoder.json");
        var encoder = JsonConvert.DeserializeObject<Dictionary<string, int>>(File.ReadAllText(encoderPath));
        if (encoder == null)
        {
            // An empty file or a literal JSON null deserializes to null rather than throwing.
            throw new InvalidDataException($"'{encoderPath}' does not contain a JSON object.");
        }

        var bpeMerges = new List<Tuple<string, string>>(File
            .ReadAllLines(Path.Combine(modelsDir, modelName, "vocab.bpe")).Skip(1).Select(SplitWhitespace)
            .Where(NotNull));

        return new Encoder(encoder, bpeMerges);
    }

    /// <summary>
    /// Splits a line on single spaces; returns the two fields as a tuple, or null when the
    /// line does not split into exactly two fields.
    /// </summary>
    private static Tuple<string, string> SplitWhitespace(string input)
    {
        var output = input.Split(' ');
        if (output.Length != 2) return null;
        return new Tuple<string, string>(output[0], output[1]);
    }

    /// <summary>Filter predicate that drops the nulls produced by <see cref="SplitWhitespace"/>.</summary>
    private static bool NotNull(Tuple<string, string> input) => input != null;
}
}
40 changes: 40 additions & 0 deletions ports/win32-csharp/src/Form1.Designer.cs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

39 changes: 39 additions & 0 deletions ports/win32-csharp/src/Form1.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using System;
using System.Collections.Generic;
using System.ComponentModel;
using System.Data;
using System.Drawing;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using System.Windows.Forms;
using Tensorflow;
using Tensorflow.Contrib.Train;
using static Tensorflow.Binding;

namespace AIDungeon.net
{
/// <summary>
/// Windows Forms form whose constructor creates a TensorFlow session and an int32
/// placeholder, then restores the latest checkpoint found in a model directory.
/// NOTE(review): this class is unfinished and does not compile as written; see the
/// notes on the individual statements below.
/// </summary>
public partial class Form1 : Form
{
/// <summary>
/// Initializes the designer-generated components, then sets up the TensorFlow
/// session and restores the checkpoint.
/// </summary>
public Form1()
{
InitializeComponent();

// Session configured so GPU memory is allowed to grow rather than fixed up front.
var config = new ConfigProto {GpuOptions = new GPUOptions {AllowGrowth = true}};
var sess = tf.Session(config);
var context = tf.placeholder(tf.int32, new TensorShape(1));

// NOTE(review): this statement has no terminating ';' and passes no arguments,
// while SampleSequence below declares eight parameters. It will not compile
// until the call is completed.
var output = SampleSequence()

var saver = tf.train.Saver();
// NOTE(review): absolute, machine-specific checkpoint directory.
var ckpt = tf.train.latest_checkpoint(@"F:\src\AIDungeon\generator\gpt2\models\model_v5");
saver.restore(sess, ckpt);
}

/// <summary>
/// Placeholder for the sequence sampler.
/// NOTE(review): the body is empty although the declared return type is object,
/// so this method needs an implementation (or a throw) before it can compile.
/// The parameter meanings are not established by anything visible in this file.
/// </summary>
public object SampleSequence(HParams hparams, int length, string start_token, int batch_size, Tensor context,
int temperature, int top_k, int top_p)
{

}
}
}
66 changes: 66 additions & 0 deletions ports/win32-csharp/src/GPT2Generator.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Newtonsoft.Json;
using Tensorflow;
using static Tensorflow.Binding;

namespace AIDungeon.net
{
/// <summary>
/// Bundles the settings, tokenizer and TensorFlow session used for text generation.
/// The constructor loads the encoder and hyper-parameters for a fixed model name from
/// a fixed relative directory, then creates the session and the context placeholder.
/// </summary>
public class GPT2Generator
{
    /// <summary>Shared random source used to pick a seed in the constructor.</summary>
    public static readonly Random rng = new Random();

    /// <summary>Value of the <c>generateNum</c> constructor argument.</summary>
    public int GenerateNum { get; set; }

    /// <summary>Value of the <c>temperature</c> constructor argument.</summary>
    public decimal Temp { get; set; }

    /// <summary>Value of the <c>topK</c> constructor argument.</summary>
    public decimal TopK { get; set; }

    /// <summary>Value of the <c>topP</c> constructor argument.</summary>
    public decimal TopP { get; set; }

    /// <summary>Value of the <c>censor</c> constructor argument.</summary>
    public bool Censor { get; set; }

    /// <summary>Model sub-directory name; set to <c>model_v5</c> by the constructor.</summary>
    public string ModelName { get; set; }

    /// <summary>Directory containing the model sub-directories; set to <c>generator/gpt2/models</c>.</summary>
    public string ModelDir { get; set; }

    /// <summary><see cref="ModelDir"/> combined with <see cref="ModelName"/>.</summary>
    public string CheckpointPath { get; set; }

    /// <summary>Size used for the context placeholder shape; set to 1 by the constructor.</summary>
    public int BatchSize { get; set; }

    /// <summary>Set to 1 by the constructor; not otherwise used in this class.</summary>
    public int Samples { get; set; }

    /// <summary>Tokenizer loaded from the model directory.</summary>
    public Encoder Encoder { get; set; }

    /// <summary>TensorFlow session created by the constructor.</summary>
    public Session Session { get; set; }

    /// <summary>int32 placeholder created by the constructor.</summary>
    public Tensor Context { get; set; }

    /// <summary>Result of <c>Sample.SampleSequence()</c>.</summary>
    public object Output { get; set; }

    /// <summary>
    /// Stores the generation settings, loads the encoder and <c>hparams.json</c> for the
    /// model, and creates the TensorFlow session, context placeholder and sampling output.
    /// </summary>
    /// <param name="generateNum">Stored in <see cref="GenerateNum"/>.</param>
    /// <param name="temperature">Stored in <see cref="Temp"/>.</param>
    /// <param name="topK">Stored in <see cref="TopK"/>.</param>
    /// <param name="topP">Stored in <see cref="TopP"/>.</param>
    /// <param name="censor">Stored in <see cref="Censor"/>.</param>
    /// <exception cref="InvalidDataException"><c>hparams.json</c> deserializes to null.</exception>
    public GPT2Generator(int generateNum, decimal temperature, decimal topK, decimal topP, bool censor)
    {
        GenerateNum = generateNum;
        Temp = temperature;
        TopK = topK;
        TopP = topP;
        Censor = censor;

        ModelName = "model_v5";
        ModelDir = "generator/gpt2/models";
        CheckpointPath = Path.Combine(ModelDir, ModelName);

        var modelsDir = Environment.ExpandEnvironmentVariables(ModelDir);
        BatchSize = 1;
        Samples = 1;

        Encoder = Encoder.GetEncoder(ModelName, modelsDir);

        var hParamsPath = Path.Combine(modelsDir, ModelName, "hparams.json");
        var hParams =
            JsonConvert.DeserializeObject<Dictionary<string, int>>(File.ReadAllText(hParamsPath));
        if (hParams == null)
        {
            // An empty file or a literal JSON null deserializes to null; report the file
            // instead of failing with a NullReferenceException in the loop below.
            throw new InvalidDataException($"'{hParamsPath}' does not contain a JSON object.");
        }

        // Fill in any hyper-parameter the file leaves out with the model defaults.
        foreach (var kvp in GPT2Model.DefaultHParams())
        {
            if (!hParams.ContainsKey(kvp.Key)) hParams[kvp.Key] = kvp.Value;
        }

        // NOTE(review): seed and hParams are computed but not consumed yet; presumably
        // they are meant to be passed to the sampler once it takes arguments.
        var seed = rng.Next(100001);

        var config = new ConfigProto {GpuOptions = new GPUOptions {AllowGrowth = true}};
        Session = tf.Session(config);
        Context = tf.placeholder(tf.int32, new TensorShape(BatchSize));
        Output = Sample.SampleSequence();
    }
}
}
58 changes: 58 additions & 0 deletions ports/win32-csharp/src/GPT2Model.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Threading.Tasks;
using Tensorflow;
using static Tensorflow.Binding;

namespace AIDungeon.net
{
/// <summary>
/// Static helpers that build the GPT-2 graph with TensorFlow.NET.
/// NOTE(review): this class is unfinished and does not compile as written; see the
/// notes on <c>Model</c> and <c>ExpandTile</c> below.
/// </summary>
public class GPT2Model
{
/// <summary>
/// Returns a new dictionary of default hyper-parameters:
/// n_vocab 0, n_ctx 1024, n_embd 768, n_head 12, n_layer 12.
/// </summary>
public static Dictionary<string, int> DefaultHParams()
{
return new Dictionary<string, int>
{{"n_vocab", 0}, {"n_ctx", 1024}, {"n_embd", 768}, {"n_head", 12}, {"n_layer", 12}};
}

/// <summary>
/// Starts building the model inside the given variable scope: creates the "wpe" and
/// "wte" variables and sums the gathered token and position embeddings.
/// NOTE(review): the method declares return type object but contains no return
/// statement, and results, batch, sequence and h are not used after assignment.
/// </summary>
/// <param name="hParams">Hyper-parameters; reads "n_ctx", "n_embd" and "n_vocab".</param>
/// <param name="x">Token tensor used as gather indices into "wte".</param>
/// <param name="past">Optional earlier state; may be null.</param>
/// <param name="scope">Name passed to the variable scope.</param>
/// <param name="autoReuse">The scope reuses variables only when this equals AUTO_REUSE.</param>
public static object Model(IDictionary<string, int> hParams, Tensor x, Tensor past, string scope, _ReuseMode autoReuse)
{
using (var varScope = tf.variable_scope(name: scope, reuse: autoReuse == _ReuseMode.AUTO_REUSE))
{
var results = new Dictionary<string, object>();
(int batch, Tensor sequence) = ShapeList(x);

// "wpe" is shaped [n_ctx, n_embd] and is indexed by position below.
var wpe = tf.get_variable("wpe", new TensorShape(hParams["n_ctx"], hParams["n_embd"]),
initializer: tf.random_normal_initializer(stddev: 0.01F));

// "wte" is shaped [n_vocab, n_embd] and is indexed by the tokens in x below.
var wte = tf.get_variable("wte", new TensorShape(hParams["n_vocab"], hParams["n_embd"]),
initializer: tf.random_normal_initializer(stddev: 0.02F));

// Zero when there is no past; otherwise the last entry of past's dynamic shape.
// NOTE(review): the reference Python GPT-2 model appears to read the second-to-last
// dimension here. Confirm which axis is intended.
var past_length = past == null ? new Tensor(0) : tf.shape(past)[past.shape.Length - 1];

var h = tf.gather(wte, x) + tf.gather(wpe, PositionsFor(x, past_length));
}
}

/// <summary>
/// Builds position indices for <paramref name="tokens"/>: pastLength plus a range over
/// the second dimension of tokens, passed to ExpandTile with the first dimension as size.
/// </summary>
private static Tensor PositionsFor(Tensor tokens, Tensor pastLength)
{
var batch_size = tf.shape(tokens)[0];
var nSteps = tf.shape(tokens)[1];
return ExpandTile(pastLength + tf.range(nSteps), batch_size);
}

/// <summary>
/// Adds a leading axis to <paramref name="value"/> and tiles the result.
/// NOTE(review): the "multiples" declaration has no terminating ';' and the variable is
/// never used, so this method does not compile as written.
/// NOTE(review): the tile argument is the arithmetic expression size + 1 * nDims. If the
/// intent is a multiples list of the form [size, 1, ..., 1], this needs rework. Confirm.
/// </summary>
private static Tensor ExpandTile(Tensor value, Tensor size)
{
// outValue is used only to read the rank; expand_dims below is given the original value.
var outValue = tf.convert_to_tensor(value, name: "value");
var nDims = outValue.shape.Length;
var multiples = new Tensor()
return tf.tile<Tensor>(tf.expand_dims(value, axis: 0), size + new Tensor(1) * nDims);
}

/// <summary>
/// Returns the first entry of x's static shape together with the second entry of its
/// dynamic shape.
/// </summary>
private static (int batch, Tensor sequence) ShapeList(Tensor x)
{
return (x.shape[0], tf.shape(x)[1]);
}
}
}
Loading