1. 目的
最近在做一个SQL操作类,大部分工作都是对SQLHelper做一层封装。最后希望实现一个可以执行参数化SQL的方法:SQL语句中可以像string.Format(formatStr, args)中的formatStr一样,用{0}、{1}……作为参数占位符,而这些占位符实际会被转换为SQL参数(如@arg__0),参数值parameterValues则最终被转换为相应类型的SqlParameter[]数组。
原型如下:
/// <summary>
/// Executes a scalar query whose SQL uses string.Format-style placeholders ({0}, {1}, ...).
/// When values are supplied, each placeholder is rewritten by SQlParameterFormatter into a SQL
/// parameter name (@arg__0, @arg__1, ...) and the values are bound as SqlParameters instead of
/// being spliced into the SQL text.
/// </summary>
/// <param name="commandText">SQL text; must not be null or empty.</param>
/// <param name="parameterValues">Values for the placeholders, by index.</param>
/// <returns>Whatever ExecuteScalar returns for the final command (defined elsewhere in this class).</returns>
/// <exception cref="ArgumentNullException"><paramref name="commandText"/> is null or empty.</exception>
public object SqlScalar(string commandText, params object[] parameterValues)
{
    if (string.IsNullOrEmpty(commandText)) throw new ArgumentNullException("commandText");

    // No values to bind: the text is executed verbatim, so placeholders and "{{" / "}}"
    // escapes are NOT processed on this path.
    // (The previous version also called GetCommand(commandText) here and discarded the
    // resulting SqlCommand without disposing it; that unused call has been removed.)
    if (parameterValues == null || parameterValues.Length == 0)
    {
        return ExecuteScalar(CommandType.Text, commandText, (SqlParameter[])null);
    }

    SQlParameterFormatter formatter = new SQlParameterFormatter();
    formatter.Format(commandText, parameterValues);
    return ExecuteScalar(CommandType.Text, formatter.Sql, formatter.Parameters);
}
2 {
3 if (commandText == null || commandText.Length == 0) throw new ArgumentNullException("commandText");
4 SqlCommand cmd = GetCommand(commandText);
5 if ((parameterValues != null) && (parameterValues.Length > 0))
6 {
7 SQlParameterFormatter formatter = new SQlParameterFormatter();
8 formatter.Format(commandText, parameterValues);
9 return ExecuteScalar(CommandType.Text, formatter.Sql, formatter.Parameters);
10 }
11 else
12 {
13 return ExecuteScalar(CommandType.Text, commandText, (SqlParameter[])null);
14 }
15
16 }
调用时:
object result = SqlScalar("select Fid from students where name={0} and age={1}", "靳同学", 20);
而最终sql语句被转换为:
select Fid from students where name=@arg__0 and age=@arg__1
这样做的好处是:参数值不会被拼接进SQL文本,而是作为SqlParameter单独传给数据库,因而可以避免SQL注入,代码也很简洁。
2. 实现
上面SqlScalar方法的第7行(new SQlParameterFormatter()那一行)用到了一个重要的类SQlParameterFormatter,它有两个属性:
public string Sql 和 public SqlParameter[] Parameters。这个类负责解析传入的复合格式SQL及其参数值,处理后把结果分别赋给属性Sql和Parameters。
最初本想用正则表达式实现,后来突然想到string.Format(),两者的确很像。既然微软已经公布了.NET Framework的源码,何不借用一下呢?调出源码一看,代码还真不少。实际上string.Format()内部调用的是StringBuilder的AppendFormat()方法,下面来看看StringBuilder类的AppendFormat():
Code
/// <summary>
/// Appends <paramref name="format"/> to this StringBuilder, replacing each format item
/// {index[,alignment][:formatString]} with the text of args[index]. "{{" and "}}" are
/// written as a literal "{" and "}".
/// </summary>
/// <param name="provider">Optional provider; if it supplies an ICustomFormatter, that formatter is tried first for every argument.</param>
/// <param name="format">The composite format string.</param>
/// <param name="args">The values referenced by index from the format items.</param>
/// <returns>This StringBuilder instance.</returns>
/// <exception cref="ArgumentNullException"><paramref name="format"/> or <paramref name="args"/> is null.</exception>
/// <exception cref="FormatException">A format item's index is &gt;= args.Length. Other malformed input is reported via FormatError(), which is not shown here (presumably it throws FormatException).</exception>
public StringBuilder AppendFormat(IFormatProvider provider, String format, params Object[] args) {
if (format == null || args == null) {
throw new ArgumentNullException((format==null)?"format":"args");
}
// Work on a private copy: literal runs are compacted in place (escapes collapsed) before being appended.
char[] chars = format.ToCharArray(0, format.Length);
int pos = 0;
int len = chars.Length;
char ch = '\x0';
// If the provider supplies an ICustomFormatter, it gets the first chance to format each argument.
ICustomFormatter cf = null;
if (provider!=null) {
cf=(ICustomFormatter)provider.GetFormat(typeof(ICustomFormatter));
}
// Each pass appends one run of literal text, then processes at most one format item.
while (true) {
// p = start of the literal run, i = write cursor for its compacted copy.
int p = pos;
int i = pos;
while (pos < len) {
ch = chars[pos];
pos++;
// A lone '}' in literal text is an error; "}}" is an escaped brace.
if (ch == '}')
{
if(pos < len && chars[pos]=='}') // Treat as escape character for }}
pos++;
else
FormatError();
}
// A lone '{' starts a format item: back up so pos points at it and stop copying literal text.
if (ch == '{')
{
if(pos < len && chars[pos]=='{') // Treat as escape character for {{
pos++;
else
{
pos--;
break;
}
}
chars[i++] = ch;
}
// Flush the literal run; stop once the whole format string has been consumed.
if (i > p) Append(chars, p, i - p);
if (pos == len) break;
// Skip the '{' and parse the argument index: at least one digit is required, and digit
// accumulation stops once index reaches 1000000.
pos++;
if (pos == len || (ch = chars[pos]) < '0' || ch > '9') FormatError();
int index = 0;
do {
index = index * 10 + ch - '0';
pos++;
if (pos == len) FormatError();
ch = chars[pos];
} while (ch >= '0' && ch <= '9' && index < 1000000);
if (index >= args.Length) throw new FormatException(Environment.GetResourceString("Format_IndexOutOfRange"));
// Skip spaces, then parse the optional alignment component ",[-]width";
// a '-' sign means left-justify (pad on the right).
while (pos < len && (ch=chars[pos]) == ' ') pos++;
bool leftJustify = false;
int width = 0;
if (ch == ',') {
pos++;
while (pos < len && chars[pos] == ' ') pos++;
if (pos == len) FormatError();
ch = chars[pos];
if (ch == '-') {
leftJustify = true;
pos++;
if (pos == len) FormatError();
ch = chars[pos];
}
if (ch < '0' || ch > '9') FormatError();
do {
width = width * 10 + ch - '0';
pos++;
if (pos == len) FormatError();
ch = chars[pos];
} while (ch >= '0' && ch <= '9' && width < 1000000);
}
while (pos < len && (ch=chars[pos]) == ' ') pos++;
Object arg = args[index];
// Optional format-string component ":fmt", read up to the closing '}' with brace escapes collapsed.
String fmt = null;
if (ch == ':') {
pos++;
p = pos;
i = pos;
while (true) {
if (pos == len) FormatError();
ch = chars[pos];
pos++;
if (ch == '{')
{
if(pos < len && chars[pos]=='{') // Treat as escape character for {{
pos++;
else
FormatError();
}
else if (ch == '}')
{
if(pos < len && chars[pos]=='}') // Treat as escape character for }}
pos++;
else
{
pos--;
break;
}
}
chars[i++] = ch;
}
if (i > p) fmt = new String(chars, p, i - p);
}
// The format item must end with '}'.
if (ch != '}') FormatError();
pos++;
// Resolve the argument's text: ICustomFormatter first, then IFormattable.ToString(fmt, provider),
// then plain ToString(); a null result becomes the empty string.
String s = null;
if (cf != null) {
s = cf.Format(fmt, arg, provider);
}
if (s==null) {
if (arg is IFormattable) {
s = ((IFormattable)arg).ToString(fmt, provider);
} else if (arg != null) {
s = arg.ToString();
}
}
if (s == null) s = String.Empty;
// Pad with spaces up to 'width' characters: on the left by default, on the right when left-justified.
int pad = width - s.Length;
if (!leftJustify && pad > 0) Append(' ', pad);
Append(s);
if (leftJustify && pad > 0) Append(' ', pad);
}
return this;
}
/// <summary>
/// Appends <paramref name="format"/> to this StringBuilder, replacing each format item
/// {index[,alignment][:formatString]} with the text of args[index]. "{{" and "}}" are
/// written as a literal "{" and "}".
/// </summary>
/// <param name="provider">Optional provider; if it supplies an ICustomFormatter, that formatter is tried first for every argument.</param>
/// <param name="format">The composite format string.</param>
/// <param name="args">The values referenced by index from the format items.</param>
/// <returns>This StringBuilder instance.</returns>
/// <exception cref="ArgumentNullException"><paramref name="format"/> or <paramref name="args"/> is null.</exception>
/// <exception cref="FormatException">A format item's index is &gt;= args.Length. Other malformed input is reported via FormatError(), which is not shown here (presumably it throws FormatException).</exception>
public StringBuilder AppendFormat(IFormatProvider provider, String format, params Object[] args) {
if (format == null || args == null) {
throw new ArgumentNullException((format==null)?"format":"args");
}
// Work on a private copy: literal runs are compacted in place (escapes collapsed) before being appended.
char[] chars = format.ToCharArray(0, format.Length);
int pos = 0;
int len = chars.Length;
char ch = '\x0';
// If the provider supplies an ICustomFormatter, it gets the first chance to format each argument.
ICustomFormatter cf = null;
if (provider!=null) {
cf=(ICustomFormatter)provider.GetFormat(typeof(ICustomFormatter));
}
// Each pass appends one run of literal text, then processes at most one format item.
while (true) {
// p = start of the literal run, i = write cursor for its compacted copy.
int p = pos;
int i = pos;
while (pos < len) {
ch = chars[pos];
pos++;
// A lone '}' in literal text is an error; "}}" is an escaped brace.
if (ch == '}')
{
if(pos < len && chars[pos]=='}') // Treat as escape character for }}
pos++;
else
FormatError();
}
// A lone '{' starts a format item: back up so pos points at it and stop copying literal text.
if (ch == '{')
{
if(pos < len && chars[pos]=='{') // Treat as escape character for {{
pos++;
else
{
pos--;
break;
}
}
chars[i++] = ch;
}
// Flush the literal run; stop once the whole format string has been consumed.
if (i > p) Append(chars, p, i - p);
if (pos == len) break;
// Skip the '{' and parse the argument index: at least one digit is required, and digit
// accumulation stops once index reaches 1000000.
pos++;
if (pos == len || (ch = chars[pos]) < '0' || ch > '9') FormatError();
int index = 0;
do {
index = index * 10 + ch - '0';
pos++;
if (pos == len) FormatError();
ch = chars[pos];
} while (ch >= '0' && ch <= '9' && index < 1000000);
if (index >= args.Length) throw new FormatException(Environment.GetResourceString("Format_IndexOutOfRange"));
// Skip spaces, then parse the optional alignment component ",[-]width";
// a '-' sign means left-justify (pad on the right).
while (pos < len && (ch=chars[pos]) == ' ') pos++;
bool leftJustify = false;
int width = 0;
if (ch == ',') {
pos++;
while (pos < len && chars[pos] == ' ') pos++;
if (pos == len) FormatError();
ch = chars[pos];
if (ch == '-') {
leftJustify = true;
pos++;
if (pos == len) FormatError();
ch = chars[pos];
}
if (ch < '0' || ch > '9') FormatError();
do {
width = width * 10 + ch - '0';
pos++;
if (pos == len) FormatError();
ch = chars[pos];
} while (ch >= '0' && ch <= '9' && width < 1000000);
}
while (pos < len && (ch=chars[pos]) == ' ') pos++;
Object arg = args[index];
// Optional format-string component ":fmt", read up to the closing '}' with brace escapes collapsed.
String fmt = null;
if (ch == ':') {
pos++;
p = pos;
i = pos;
while (true) {
if (pos == len) FormatError();
ch = chars[pos];
pos++;
if (ch == '{')
{
if(pos < len && chars[pos]=='{') // Treat as escape character for {{
pos++;
else
FormatError();
}
else if (ch == '}')
{
if(pos < len && chars[pos]=='}') // Treat as escape character for }}
pos++;
else
{
pos--;
break;
}
}
chars[i++] = ch;
}
if (i > p) fmt = new String(chars, p, i - p);
}
// The format item must end with '}'.
if (ch != '}') FormatError();
pos++;
// Resolve the argument's text: ICustomFormatter first, then IFormattable.ToString(fmt, provider),
// then plain ToString(); a null result becomes the empty string.
String s = null;
if (cf != null) {
s = cf.Format(fmt, arg, provider);
}
if (s==null) {
if (arg is IFormattable) {
s = ((IFormattable)arg).ToString(fmt, provider);
} else if (arg != null) {
s = arg.ToString();
}
}
if (s == null) s = String.Empty;
// Pad with spaces up to 'width' characters: on the left by default, on the right when left-justified.
int pad = width - s.Length;
if (!leftJustify && pad > 0) Append(' ', pad);
Append(s);
if (leftJustify && pad > 0) Append(' ', pad);
}
return this;
}
可以看出,这种逐字符扫描、一次遍历即完成解析的方式,效率还是蛮高的。
下面是我借用AppendFormat处理参数的方式来实现的SQlParameterFormatter类
Code
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Data.SqlClient;
using System.Data;
namespace Jrt.Sql
{
    /// <summary>
    /// Turns a composite-format SQL string, whose placeholders are written like those of
    /// string.Format ({0}, {1}, ...), into parameterized SQL (<see cref="Sql"/>) plus a matching
    /// <see cref="SqlParameter"/> array (<see cref="Parameters"/>). Placeholder {n} becomes @arg__n.
    /// Author: 靳如坦 (代码乱了), 2008-09-25.
    /// </summary>
    /// <remarks>
    /// The parser is adapted from StringBuilder.AppendFormat: "{{" and "}}" produce literal braces,
    /// an alignment component ({0,8} / {0,-8}) pads the parameter name with spaces, and a format
    /// component ({0:N2}) is syntax-checked but otherwise ignored, because the value is sent as a
    /// typed parameter rather than rendered as text.
    /// </remarks>
    public class SQlParameterFormatter
    {
        private string sql;
        private SqlParameter[] parameters;

        /// <summary>
        /// Parameterized SQL produced by the last successful <see cref="Format"/> call; null before the first call.
        /// </summary>
        public string Sql
        {
            get { return sql; }
        }

        /// <summary>
        /// Parameters produced by the last successful <see cref="Format"/> call: one per distinct
        /// placeholder index, in order of first appearance; null before the first call.
        /// </summary>
        public SqlParameter[] Parameters
        {
            get { return parameters; }
        }

        private static void FormatError()
        {
            throw new FormatException("参数格式错误");
        }

        /// <summary>
        /// Parses <paramref name="format"/> and stores the resulting SQL and parameters in
        /// <see cref="Sql"/> and <see cref="Parameters"/>.
        /// </summary>
        /// <param name="format">SQL text containing {index[,alignment][:format]} placeholders.</param>
        /// <param name="args">Placeholder values by index; a null entry is sent as <see cref="DBNull.Value"/>.</param>
        /// <exception cref="ArgumentNullException"><paramref name="format"/> or <paramref name="args"/> is null.</exception>
        /// <exception cref="FormatException">A placeholder is malformed, a brace is unbalanced, or an index is &gt;= args.Length.</exception>
        /// <exception cref="ArgumentOutOfRangeException">A referenced value's type cannot be mapped by <see cref="ConvertSqlType"/>.</exception>
        public void Format(String format, params Object[] args)
        {
            if (format == null || args == null)
            {
                throw new ArgumentNullException((format == null) ? "format" : "args");
            }

            List<SqlParameter> commandParameters = new List<SqlParameter>();
            // declared[n] becomes true once @arg__n has been added. A placeholder may appear
            // several times ("... {0} ... {0}"), but SQL Server rejects a parameter that is
            // declared twice, so only the first occurrence creates a SqlParameter.
            bool[] declared = new bool[args.Length];
            StringBuilder sb = new StringBuilder();
            // Private copy: literal text is compacted in place while "{{" / "}}" are collapsed.
            char[] chars = format.ToCharArray(0, format.Length);
            int pos = 0;
            int len = chars.Length;
            char ch = '\x0';
            while (true)
            {
                // 1. Copy literal text up to the next unescaped '{'.
                int p = pos;
                int i = pos;
                while (pos < len)
                {
                    ch = chars[pos];
                    pos++;
                    if (ch == '}')
                    {
                        if (pos < len && chars[pos] == '}') // Treat as escape character for }}
                        {
                            pos++;
                        }
                        else
                        {
                            FormatError();
                        }
                    }
                    if (ch == '{')
                    {
                        if (pos < len && chars[pos] == '{') // Treat as escape character for {{
                        {
                            pos++;
                        }
                        else
                        {
                            pos--;
                            break;
                        }
                    }
                    chars[i++] = ch;
                }
                if (i > p)
                {
                    sb.Append(chars, p, i - p);
                }
                if (pos == len)
                {
                    break;
                }

                // 2. Skip the '{' and parse the argument index (at least one digit).
                pos++;
                if (pos == len || (ch = chars[pos]) < '0' || ch > '9') FormatError();
                int index = 0;
                do
                {
                    index = index * 10 + ch - '0';
                    pos++;
                    if (pos == len) FormatError();
                    ch = chars[pos];
                } while (ch >= '0' && ch <= '9' && index < 1000000);
                if (index >= args.Length) throw new FormatException("索引(从零开始)必须大于或等于零,且小于参数列表的大小。");
                while (pos < len && (ch = chars[pos]) == ' ') pos++;

                // 3. Optional alignment ",[-]width": pads the parameter name in the output SQL.
                bool leftJustify = false;
                int width = 0;
                if (ch == ',')
                {
                    pos++;
                    while (pos < len && chars[pos] == ' ') pos++;
                    if (pos == len) FormatError();
                    ch = chars[pos];
                    if (ch == '-')
                    {
                        leftJustify = true;
                        pos++;
                        if (pos == len) FormatError();
                        ch = chars[pos];
                    }
                    if (ch < '0' || ch > '9') FormatError();
                    do
                    {
                        width = width * 10 + ch - '0';
                        pos++;
                        if (pos == len) FormatError();
                        ch = chars[pos];
                    } while (ch >= '0' && ch <= '9' && width < 1000000);
                }
                while (pos < len && (ch = chars[pos]) == ' ') pos++;

                // 4. Optional ":format" component: checked for balanced braces, then ignored,
                //    because the value is bound as a typed parameter instead of rendered as text.
                if (ch == ':')
                {
                    pos++;
                    while (true)
                    {
                        if (pos == len) FormatError();
                        ch = chars[pos];
                        pos++;
                        if (ch == '{')
                        {
                            if (pos < len && chars[pos] == '{') // Treat as escape character for {{
                            {
                                pos++;
                            }
                            else
                            {
                                FormatError();
                            }
                        }
                        else if (ch == '}')
                        {
                            if (pos < len && chars[pos] == '}') // Treat as escape character for }}
                            {
                                pos++;
                            }
                            else
                            {
                                pos--;
                                break;
                            }
                        }
                    }
                }
                if (ch != '}') FormatError();
                pos++;

                // 5. Emit the parameter name; create its SqlParameter only on first use of this index.
                string parameterName = "@arg__" + index.ToString();
                if (!declared[index])
                {
                    object value = args[index] ?? DBNull.Value;
                    SqlParameter para = new SqlParameter();
                    para.ParameterName = parameterName;
                    para.SqlDbType = ConvertSqlType(value.GetType());
                    para.Value = value;
                    commandParameters.Add(para);
                    declared[index] = true;
                }
                int pad = width - parameterName.Length;
                if (!leftJustify && pad > 0) sb.Append(' ', pad);
                sb.Append(parameterName);
                if (leftJustify && pad > 0) sb.Append(' ', pad);
            }
            this.parameters = commandParameters.ToArray();
            this.sql = sb.ToString();
        }

        /// <summary>
        /// Maps the runtime type of a parameter value to the <see cref="SqlDbType"/> assigned to its SqlParameter.
        /// </summary>
        /// <param name="type">Runtime type of a parameter value.</param>
        /// <returns>The SqlDbType to use for the parameter.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="type"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException">The type has no mapping.</exception>
        /// <remarks>
        /// Matches on Type identity and TypeCode instead of lower-cased type names. ToLower() is
        /// culture-sensitive: under tr-TR, "System.Int32" lower-cases to "system.ınt32" (dotless i)
        /// and never matched. Also, byte[] is named "System.Byte[]", not the VB-style "System.Byte()"
        /// the old table expected. Enums map through their underlying integer type.
        /// </remarks>
        public static SqlDbType ConvertSqlType(Type type)
        {
            if (type == null)
            {
                throw new ArgumentNullException("type");
            }
            if (type == typeof(byte[])) return SqlDbType.VarBinary;
            if (type == typeof(Guid)) return SqlDbType.UniqueIdentifier;
            if (type == typeof(object)) return SqlDbType.Variant;

            switch (Type.GetTypeCode(type))
            {
                case TypeCode.Int64:
                case TypeCode.UInt32:
                    return SqlDbType.BigInt;
                case TypeCode.UInt64:
                    // BIGINT is signed; DECIMAL holds the full UInt64 range.
                    return SqlDbType.Decimal;
                case TypeCode.Int32:
                case TypeCode.UInt16:
                    return SqlDbType.Int;
                case TypeCode.Int16:
                case TypeCode.SByte:
                    // TINYINT is unsigned (0..255) and BIT is only 0/1; SMALLINT covers -128..127.
                    return SqlDbType.SmallInt;
                case TypeCode.Byte:
                    return SqlDbType.TinyInt;
                case TypeCode.Boolean:
                    return SqlDbType.Bit;
                case TypeCode.DateTime:
                    return SqlDbType.DateTime;
                case TypeCode.Decimal:
                    return SqlDbType.Decimal;
                case TypeCode.Double:
                    return SqlDbType.Float;
                case TypeCode.Single:
                    return SqlDbType.Real;
                case TypeCode.String:
                    // Unicode type, so non-ASCII text (e.g. Chinese names) is not converted to the
                    // database code page as it would be with VARCHAR.
                    return SqlDbType.NVarChar;
                case TypeCode.Char:
                    return SqlDbType.NChar;
                case TypeCode.DBNull:
                    // Type of a NULL value (Format maps null to DBNull.Value). NVARCHAR is the
                    // default SqlParameter type, i.e. what SqlClient uses for an untyped DBNull.
                    return SqlDbType.NVarChar;
                default:
                    throw new ArgumentOutOfRangeException("type", "不支持的参数类型:" + type.FullName);
            }
        }
    }
}
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Data.SqlClient;
using System.Data;
namespace Jrt.Sql
{
    /// <summary>
    /// Turns a composite-format SQL string, whose placeholders are written like those of
    /// string.Format ({0}, {1}, ...), into parameterized SQL (<see cref="Sql"/>) plus a matching
    /// <see cref="SqlParameter"/> array (<see cref="Parameters"/>). Placeholder {n} becomes @arg__n.
    /// Author: 靳如坦 (代码乱了), 2008-09-25.
    /// </summary>
    /// <remarks>
    /// The parser is adapted from StringBuilder.AppendFormat: "{{" and "}}" produce literal braces,
    /// an alignment component ({0,8} / {0,-8}) pads the parameter name with spaces, and a format
    /// component ({0:N2}) is syntax-checked but otherwise ignored, because the value is sent as a
    /// typed parameter rather than rendered as text.
    /// </remarks>
    public class SQlParameterFormatter
    {
        private string sql;
        private SqlParameter[] parameters;

        /// <summary>
        /// Parameterized SQL produced by the last successful <see cref="Format"/> call; null before the first call.
        /// </summary>
        public string Sql
        {
            get { return sql; }
        }

        /// <summary>
        /// Parameters produced by the last successful <see cref="Format"/> call: one per distinct
        /// placeholder index, in order of first appearance; null before the first call.
        /// </summary>
        public SqlParameter[] Parameters
        {
            get { return parameters; }
        }

        private static void FormatError()
        {
            throw new FormatException("参数格式错误");
        }

        /// <summary>
        /// Parses <paramref name="format"/> and stores the resulting SQL and parameters in
        /// <see cref="Sql"/> and <see cref="Parameters"/>.
        /// </summary>
        /// <param name="format">SQL text containing {index[,alignment][:format]} placeholders.</param>
        /// <param name="args">Placeholder values by index; a null entry is sent as <see cref="DBNull.Value"/>.</param>
        /// <exception cref="ArgumentNullException"><paramref name="format"/> or <paramref name="args"/> is null.</exception>
        /// <exception cref="FormatException">A placeholder is malformed, a brace is unbalanced, or an index is &gt;= args.Length.</exception>
        /// <exception cref="ArgumentOutOfRangeException">A referenced value's type cannot be mapped by <see cref="ConvertSqlType"/>.</exception>
        public void Format(String format, params Object[] args)
        {
            if (format == null || args == null)
            {
                throw new ArgumentNullException((format == null) ? "format" : "args");
            }

            List<SqlParameter> commandParameters = new List<SqlParameter>();
            // declared[n] becomes true once @arg__n has been added. A placeholder may appear
            // several times ("... {0} ... {0}"), but SQL Server rejects a parameter that is
            // declared twice, so only the first occurrence creates a SqlParameter.
            bool[] declared = new bool[args.Length];
            StringBuilder sb = new StringBuilder();
            // Private copy: literal text is compacted in place while "{{" / "}}" are collapsed.
            char[] chars = format.ToCharArray(0, format.Length);
            int pos = 0;
            int len = chars.Length;
            char ch = '\x0';
            while (true)
            {
                // 1. Copy literal text up to the next unescaped '{'.
                int p = pos;
                int i = pos;
                while (pos < len)
                {
                    ch = chars[pos];
                    pos++;
                    if (ch == '}')
                    {
                        if (pos < len && chars[pos] == '}') // Treat as escape character for }}
                        {
                            pos++;
                        }
                        else
                        {
                            FormatError();
                        }
                    }
                    if (ch == '{')
                    {
                        if (pos < len && chars[pos] == '{') // Treat as escape character for {{
                        {
                            pos++;
                        }
                        else
                        {
                            pos--;
                            break;
                        }
                    }
                    chars[i++] = ch;
                }
                if (i > p)
                {
                    sb.Append(chars, p, i - p);
                }
                if (pos == len)
                {
                    break;
                }

                // 2. Skip the '{' and parse the argument index (at least one digit).
                pos++;
                if (pos == len || (ch = chars[pos]) < '0' || ch > '9') FormatError();
                int index = 0;
                do
                {
                    index = index * 10 + ch - '0';
                    pos++;
                    if (pos == len) FormatError();
                    ch = chars[pos];
                } while (ch >= '0' && ch <= '9' && index < 1000000);
                if (index >= args.Length) throw new FormatException("索引(从零开始)必须大于或等于零,且小于参数列表的大小。");
                while (pos < len && (ch = chars[pos]) == ' ') pos++;

                // 3. Optional alignment ",[-]width": pads the parameter name in the output SQL.
                bool leftJustify = false;
                int width = 0;
                if (ch == ',')
                {
                    pos++;
                    while (pos < len && chars[pos] == ' ') pos++;
                    if (pos == len) FormatError();
                    ch = chars[pos];
                    if (ch == '-')
                    {
                        leftJustify = true;
                        pos++;
                        if (pos == len) FormatError();
                        ch = chars[pos];
                    }
                    if (ch < '0' || ch > '9') FormatError();
                    do
                    {
                        width = width * 10 + ch - '0';
                        pos++;
                        if (pos == len) FormatError();
                        ch = chars[pos];
                    } while (ch >= '0' && ch <= '9' && width < 1000000);
                }
                while (pos < len && (ch = chars[pos]) == ' ') pos++;

                // 4. Optional ":format" component: checked for balanced braces, then ignored,
                //    because the value is bound as a typed parameter instead of rendered as text.
                if (ch == ':')
                {
                    pos++;
                    while (true)
                    {
                        if (pos == len) FormatError();
                        ch = chars[pos];
                        pos++;
                        if (ch == '{')
                        {
                            if (pos < len && chars[pos] == '{') // Treat as escape character for {{
                            {
                                pos++;
                            }
                            else
                            {
                                FormatError();
                            }
                        }
                        else if (ch == '}')
                        {
                            if (pos < len && chars[pos] == '}') // Treat as escape character for }}
                            {
                                pos++;
                            }
                            else
                            {
                                pos--;
                                break;
                            }
                        }
                    }
                }
                if (ch != '}') FormatError();
                pos++;

                // 5. Emit the parameter name; create its SqlParameter only on first use of this index.
                string parameterName = "@arg__" + index.ToString();
                if (!declared[index])
                {
                    object value = args[index] ?? DBNull.Value;
                    SqlParameter para = new SqlParameter();
                    para.ParameterName = parameterName;
                    para.SqlDbType = ConvertSqlType(value.GetType());
                    para.Value = value;
                    commandParameters.Add(para);
                    declared[index] = true;
                }
                int pad = width - parameterName.Length;
                if (!leftJustify && pad > 0) sb.Append(' ', pad);
                sb.Append(parameterName);
                if (leftJustify && pad > 0) sb.Append(' ', pad);
            }
            this.parameters = commandParameters.ToArray();
            this.sql = sb.ToString();
        }

        /// <summary>
        /// Maps the runtime type of a parameter value to the <see cref="SqlDbType"/> assigned to its SqlParameter.
        /// </summary>
        /// <param name="type">Runtime type of a parameter value.</param>
        /// <returns>The SqlDbType to use for the parameter.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="type"/> is null.</exception>
        /// <exception cref="ArgumentOutOfRangeException">The type has no mapping.</exception>
        /// <remarks>
        /// Matches on Type identity and TypeCode instead of lower-cased type names. ToLower() is
        /// culture-sensitive: under tr-TR, "System.Int32" lower-cases to "system.ınt32" (dotless i)
        /// and never matched. Also, byte[] is named "System.Byte[]", not the VB-style "System.Byte()"
        /// the old table expected. Enums map through their underlying integer type.
        /// </remarks>
        public static SqlDbType ConvertSqlType(Type type)
        {
            if (type == null)
            {
                throw new ArgumentNullException("type");
            }
            if (type == typeof(byte[])) return SqlDbType.VarBinary;
            if (type == typeof(Guid)) return SqlDbType.UniqueIdentifier;
            if (type == typeof(object)) return SqlDbType.Variant;

            switch (Type.GetTypeCode(type))
            {
                case TypeCode.Int64:
                case TypeCode.UInt32:
                    return SqlDbType.BigInt;
                case TypeCode.UInt64:
                    // BIGINT is signed; DECIMAL holds the full UInt64 range.
                    return SqlDbType.Decimal;
                case TypeCode.Int32:
                case TypeCode.UInt16:
                    return SqlDbType.Int;
                case TypeCode.Int16:
                case TypeCode.SByte:
                    // TINYINT is unsigned (0..255) and BIT is only 0/1; SMALLINT covers -128..127.
                    return SqlDbType.SmallInt;
                case TypeCode.Byte:
                    return SqlDbType.TinyInt;
                case TypeCode.Boolean:
                    return SqlDbType.Bit;
                case TypeCode.DateTime:
                    return SqlDbType.DateTime;
                case TypeCode.Decimal:
                    return SqlDbType.Decimal;
                case TypeCode.Double:
                    return SqlDbType.Float;
                case TypeCode.Single:
                    return SqlDbType.Real;
                case TypeCode.String:
                    // Unicode type, so non-ASCII text (e.g. Chinese names) is not converted to the
                    // database code page as it would be with VARCHAR.
                    return SqlDbType.NVarChar;
                case TypeCode.Char:
                    return SqlDbType.NChar;
                case TypeCode.DBNull:
                    // Type of a NULL value (Format maps null to DBNull.Value). NVARCHAR is the
                    // default SqlParameter type, i.e. what SqlClient uses for an untyped DBNull.
                    return SqlDbType.NVarChar;
                default:
                    throw new ArgumentOutOfRangeException("type", "不支持的参数类型:" + type.FullName);
            }
        }
    }
}
只是一个想法而已,没有用在实际项目中, 欢迎大家拍砖