前言
Java中的ArrayList和LinkedList都是我们很常用的数据结构,了解它们的内部实现原理可以让我们更好地使用它们。
代码实现
ArrayList
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.ListIterator;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.function.Predicate;
/**
 * A minimal growable-array implementation of {@link List}, modelled on the JDK's ArrayList.
 *
 * <p>Known limitations: {@link #subList(int, int)} returns an independent copy instead of a
 * live view, and the iterators are not fail-fast. The class is not thread-safe.
 *
 * @param <E> element type
 */
public class MyArrayList<E> implements List<E> {
    /**
     * Backing storage. Only slots {@code [0, size)} hold live elements; every other slot is
     * kept {@code null} so removed elements can be garbage collected.
     */
    private Object[] data;
    /**
     * Number of elements actually stored (not the capacity of {@link #data}).
     */
    private int size;

    /**
     * Creates an empty list with an initial capacity of 10.
     */
    public MyArrayList() {
        this(10);
    }

    /**
     * Creates an empty list with the given initial capacity.
     *
     * @param capacity initial length of the backing array; 0 is allowed
     */
    public MyArrayList(int capacity) {
        data = new Object[capacity];
    }

    /**
     * Inserts an element at the given index, shifting later elements one slot to the right.
     *
     * @param index insertion position, {@code 0 <= index <= size()}
     * @param e     element to insert
     * @throws IndexOutOfBoundsException if the index is out of range
     */
    @Override
    public void add(int index, E e) {
        rangeCheckForAdd(index);
        ensureCapacity(size + 1);
        System.arraycopy(data, index, data, index + 1, size - index);
        data[index] = e;
        size++;
    }

    /**
     * Removes the element at the given index.
     *
     * @param index position of the element to remove
     * @return the removed element
     * @throws IndexOutOfBoundsException if the index is out of range
     */
    @Override
    public E remove(int index) {
        Objects.checkIndex(index, size);
        E oldValue = elementData(index);
        fastRemove(index);
        return oldValue;
    }

    /**
     * Finds the first occurrence of an element, scanning front to back.
     *
     * @param o element to look for (may be {@code null})
     * @return its index, or -1 if absent
     */
    @Override
    public int indexOf(Object o) {
        for (int i = 0; i < size; i++) {
            if (Objects.equals(data[i], o)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Finds the last occurrence of an element, scanning back to front.
     *
     * @param o element to look for (may be {@code null})
     * @return its index, or -1 if absent
     */
    @Override
    public int lastIndexOf(Object o) {
        for (int i = size - 1; i >= 0; i--) {
            if (Objects.equals(data[i], o)) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Creates a list iterator positioned before the first element.
     *
     * @return the iterator
     */
    @Override
    public ListIterator<E> listIterator() {
        return new MyArrayListListIterator(0);
    }

    /**
     * Creates a list iterator positioned before the element at {@code index}.
     *
     * @param index starting position, {@code 0 <= index <= size()}
     * @return the iterator
     * @throws IndexOutOfBoundsException if the index is out of range
     */
    @Override
    public ListIterator<E> listIterator(int index) {
        rangeCheckForAdd(index);
        return new MyArrayListListIterator(index);
    }

    /**
     * Returns a copy of the elements in {@code [fromIndex, toIndex)}.
     *
     * <p>NOTE: unlike the {@link List} contract, the result is a snapshot, not a view;
     * changes to it are not reflected in this list and vice versa.
     *
     * @param fromIndex start index, inclusive
     * @param toIndex   end index, exclusive
     * @return a new list holding the selected elements
     * @throws IndexOutOfBoundsException if an endpoint is out of range
     * @throws IllegalArgumentException  if {@code fromIndex > toIndex}
     */
    @Override
    public List<E> subList(int fromIndex, int toIndex) {
        subListRangeCheck(fromIndex, toIndex, size);
        int count = toIndex - fromIndex;
        MyArrayList<E> copy = new MyArrayList<>(count);
        System.arraycopy(data, fromIndex, copy.data, 0, count);
        copy.size = count;
        return copy;
    }

    /**
     * Appends an element to the end of the list.
     *
     * @param e element to append
     * @return always {@code true}
     */
    @Override
    public boolean add(E e) {
        add(size, e);
        return true;
    }

    /**
     * Removes the first occurrence of an element.
     *
     * @param o element to remove (may be {@code null})
     * @return {@code true} if an element was removed
     */
    @Override
    public boolean remove(Object o) {
        int index = indexOf(o);
        if (index > -1) {
            fastRemove(index);
            return true;
        }
        return false;
    }

    /**
     * Tests whether every element of the given collection is present in this list.
     *
     * @param c collection to check
     * @return {@code true} if all elements are contained
     */
    @Override
    public boolean containsAll(Collection<?> c) {
        for (Object e : c) {
            if (!contains(e)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Appends all elements of the given collection.
     *
     * @param c source collection
     * @return {@code true} if this list changed
     */
    @Override
    public boolean addAll(Collection<? extends E> c) {
        return addAll(size, c);
    }

    /**
     * Inserts all elements of the given collection at the given index.
     *
     * @param index insertion position, {@code 0 <= index <= size()}
     * @param c     source collection
     * @return {@code true} if this list changed, i.e. {@code c} was not empty
     * @throws IndexOutOfBoundsException if the index is out of range
     */
    @Override
    public boolean addAll(int index, Collection<? extends E> c) {
        rangeCheckForAdd(index);
        // Snapshot once so the element count and the copied contents cannot disagree
        // (also makes list.addAll(list) safe).
        Object[] incoming = c.toArray();
        int count = incoming.length;
        if (count == 0) {
            return false;
        }
        ensureCapacity(size + count);
        System.arraycopy(data, index, data, index + count, size - index);
        System.arraycopy(incoming, 0, data, index, count);
        size += count;
        return true;
    }

    /**
     * Removes every element that is contained in the given collection.
     *
     * @param c elements to remove
     * @return {@code true} if this list changed
     */
    @Override
    public boolean removeAll(Collection<?> c) {
        Objects.requireNonNull(c);
        return batchRemove(item -> !c.contains(item));
    }

    /**
     * Keeps only the elements contained in the given collection.
     *
     * @param c elements to keep
     * @return {@code true} if this list changed
     */
    @Override
    public boolean retainAll(Collection<?> c) {
        Objects.requireNonNull(c);
        return batchRemove(c::contains);
    }

    /**
     * Removes every element matching the predicate.
     *
     * @param filter removal condition
     * @return {@code true} if any element was removed
     */
    @Override
    public boolean removeIf(Predicate<? super E> filter) {
        Objects.requireNonNull(filter);
        return batchRemove(filter.negate());
    }

    /**
     * Removes all elements. The capacity of the backing array is retained.
     */
    @Override
    public void clear() {
        Arrays.fill(data, 0, size, null);
        size = 0;
    }

    /**
     * Replaces the element at the given index.
     *
     * @param index position to overwrite
     * @param e     new element
     * @return the element previously stored there
     * @throws IndexOutOfBoundsException if the index is out of range
     */
    @Override
    public E set(int index, E e) {
        Objects.checkIndex(index, size);
        E oldValue = elementData(index);
        data[index] = e;
        return oldValue;
    }

    /**
     * Returns the element at the given index.
     *
     * @param index position to read
     * @return the element
     * @throws IndexOutOfBoundsException if the index is out of range
     */
    @Override
    public E get(int index) {
        Objects.checkIndex(index, size);
        return elementData(index);
    }

    /**
     * Returns the number of stored elements.
     *
     * @return element count
     */
    @Override
    public int size() {
        return size;
    }

    /**
     * Tests whether the list has no elements.
     *
     * @return {@code true} if empty
     */
    @Override
    public boolean isEmpty() {
        return size == 0;
    }

    /**
     * Tests whether the list contains the given element.
     *
     * @param o element to look for
     * @return {@code true} if present
     */
    @Override
    public boolean contains(Object o) {
        return indexOf(o) >= 0;
    }

    /**
     * Creates an iterator over the elements in index order.
     *
     * @return the iterator
     */
    @Override
    public Iterator<E> iterator() {
        return new MyArrayListIterator();
    }

    /**
     * Copies the elements into a new {@code Object[]} of length {@code size()}.
     *
     * @return the array
     */
    @Override
    public Object[] toArray() {
        return Arrays.copyOf(data, size);
    }

    /**
     * Copies the elements into the given array if it is large enough, otherwise into a new
     * array of the same runtime type. If there is room to spare, the slot right after the
     * last element is set to {@code null}.
     *
     * @param a   destination array
     * @param <T> array component type
     * @return the array holding the elements
     */
    @Override
    @SuppressWarnings("unchecked") // the copy has a's runtime type, so the cast is safe
    public <T> T[] toArray(T[] a) {
        if (a.length < size) {
            return (T[]) Arrays.copyOf(data, size, a.getClass());
        }
        System.arraycopy(data, 0, a, 0, size);
        if (a.length > size) {
            a[size] = null;
        }
        return a;
    }

    /**
     * Compares with another list element by element, as required by {@link List#equals}.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof List)) {
            return false;
        }
        Iterator<?> other = ((List<?>) o).iterator();
        for (int i = 0; i < size; i++) {
            if (!other.hasNext() || !Objects.equals(data[i], other.next())) {
                return false;
            }
        }
        return !other.hasNext();
    }

    /**
     * Computes the hash code defined by {@link List#hashCode}.
     */
    @Override
    public int hashCode() {
        int hash = 1;
        for (int i = 0; i < size; i++) {
            hash = 31 * hash + Objects.hashCode(data[i]);
        }
        return hash;
    }

    @Override
    public String toString() {
        return Arrays.toString(toArray());
    }

    /**
     * Removes the element at {@code index} without bounds checking.
     */
    private void fastRemove(int index) {
        int newSize = size - 1;
        System.arraycopy(data, index + 1, data, index, newSize - index);
        // Clear the vacated LAST slot (index size-1); data[size] may not even exist.
        data[newSize] = null;
        size = newSize;
    }

    /**
     * Keeps only elements for which {@code keep} returns true, compacting in place.
     *
     * @return {@code true} if at least one element was dropped
     */
    private boolean batchRemove(Predicate<? super E> keep) {
        final int oldSize = size;
        int read = 0;
        int kept = 0;
        try {
            for (; read < oldSize; read++) {
                if (keep.test(elementData(read))) {
                    data[kept++] = data[read];
                }
            }
        } finally {
            // If the predicate threw, preserve the not-yet-examined tail so the list
            // stays consistent (no duplicated or lost elements).
            if (read < oldSize) {
                System.arraycopy(data, read, data, kept, oldSize - read);
                kept += oldSize - read;
            }
            Arrays.fill(data, kept, oldSize, null);
            size = kept;
        }
        return kept != oldSize;
    }

    /**
     * Grows the backing array so it can hold at least {@code minCapacity} elements.
     * Growth is 1.5x, but never less than what is requested, so capacities 0 and 1
     * (where 1.5x rounds down to no growth) are handled too.
     */
    private void ensureCapacity(int minCapacity) {
        if (minCapacity < 0) {
            // int overflow of size + count
            throw new OutOfMemoryError("Required array size too large");
        }
        int oldCapacity = data.length;
        if (minCapacity <= oldCapacity) {
            return;
        }
        int newCapacity = oldCapacity + (oldCapacity >> 1);
        if (newCapacity < minCapacity) {
            newCapacity = minCapacity;
        }
        data = Arrays.copyOf(data, newCapacity);
    }

    @SuppressWarnings("unchecked") // only E instances are ever stored in data
    private E elementData(int index) {
        return (E) data[index];
    }

    private void rangeCheckForAdd(int index) {
        if (index > size || index < 0) {
            throw new IndexOutOfBoundsException(outOfBoundsMsg(index));
        }
    }

    private void subListRangeCheck(int fromIndex, int toIndex, int size) {
        if (fromIndex < 0) {
            throw new IndexOutOfBoundsException("fromIndex = " + fromIndex);
        }
        if (toIndex > size) {
            throw new IndexOutOfBoundsException("toIndex = " + toIndex);
        }
        if (fromIndex > toIndex) {
            throw new IllegalArgumentException("fromIndex(" + fromIndex +
                    ") > toIndex(" + toIndex + ")");
        }
    }

    private String outOfBoundsMsg(int index) {
        return "Index: " + index + ", Size: " + size;
    }

    /**
     * Forward iterator. Not fail-fast: concurrent structural modification is not detected.
     */
    private class MyArrayListIterator implements Iterator<E> {
        /** Index of the element the next call to next() will return. */
        int cursor;
        /** Index of the element most recently returned, or -1 if none / already consumed. */
        int lastRet = -1;

        @Override
        public boolean hasNext() {
            return cursor < size;
        }

        @Override
        public E next() {
            if (cursor >= size) {
                throw new NoSuchElementException();
            }
            lastRet = cursor++;
            return elementData(lastRet);
        }

        @Override
        public void remove() {
            if (lastRet < 0) {
                throw new IllegalStateException();
            }
            MyArrayList.this.remove(lastRet);
            // Everything after lastRet shifted left, so the next element now sits there.
            cursor = lastRet;
            lastRet = -1;
        }
    }

    /**
     * Bidirectional iterator built on top of {@link MyArrayListIterator}.
     */
    private class MyArrayListListIterator extends MyArrayListIterator implements ListIterator<E> {
        MyArrayListListIterator(int index) {
            super();
            cursor = index;
        }

        @Override
        public boolean hasPrevious() {
            return cursor > 0;
        }

        @Override
        public E previous() {
            if (cursor <= 0) {
                throw new NoSuchElementException();
            }
            lastRet = --cursor;
            return elementData(lastRet);
        }

        @Override
        public int nextIndex() {
            return cursor;
        }

        @Override
        public int previousIndex() {
            return cursor - 1;
        }

        @Override
        public void set(E e) {
            if (lastRet < 0) {
                throw new IllegalStateException();
            }
            MyArrayList.this.set(lastRet, e);
        }

        @Override
        public void add(E e) {
            MyArrayList.this.add(cursor, e);
            // The new element must be behind the cursor so next() is unaffected.
            cursor++;
            lastRet = -1;
        }
    }
}
LinkedList
/**
 * A minimal doubly-linked implementation of {@link List}, modelled on the JDK's LinkedList.
 *
 * <p>Known limitations: {@link #subList(int, int)} returns an independent copy instead of a
 * live view, and the iterators are not fail-fast. The class is not thread-safe.
 *
 * @param <E> element type
 */
public class MyLinkedList<E> implements List<E> {
    /**
     * Sentinel in front of the first real node; the first element is {@code dummyHead.next}.
     */
    private final Node<E> dummyHead;
    /**
     * Last real node, or {@link #dummyHead} itself while the list is empty, so that
     * appending never needs a special case.
     */
    private Node<E> tail;
    /**
     * Number of elements stored.
     */
    private int size;

    /**
     * Creates an empty list.
     */
    public MyLinkedList() {
        dummyHead = new Node<>(null, null, null);
        tail = dummyHead;
    }

    /**
     * Inserts an element at the given index. Appending ({@code index == size()}) is O(1).
     *
     * @param index insertion position, {@code 0 <= index <= size()}
     * @param e     element to insert
     * @throws IndexOutOfBoundsException if the index is out of range
     */
    @Override
    public void add(int index, E e) {
        rangeCheckForAdd(index);
        linkAfter(nodeBefore(index), e);
    }

    /**
     * Removes the element at the given index.
     *
     * @param index position of the element to remove
     * @return the removed element
     * @throws IndexOutOfBoundsException if the index is out of range
     */
    @Override
    public E remove(int index) {
        Objects.checkIndex(index, size);
        Node<E> node = node(index);
        E data = node.data;
        fastRemove(node);
        return data;
    }

    /**
     * Unlinks a real node, repairing BOTH neighbour links and the tail pointer.
     */
    private void fastRemove(Node<E> node) {
        Node<E> before = node.prev;
        Node<E> after = node.next;
        before.next = after;
        if (after == null) {
            tail = before;
        } else {
            after.prev = before;
        }
        // Detach so the removed node cannot keep its neighbours or payload alive.
        node.data = null;
        node.prev = null;
        node.next = null;
        size--;
    }

    /**
     * Inserts a new node holding {@code e} right after {@code prev}.
     *
     * @return the newly created node
     */
    private Node<E> linkAfter(Node<E> prev, E e) {
        Node<E> after = prev.next;
        Node<E> created = new Node<>(e, prev, after);
        prev.next = created;
        if (after == null) {
            tail = created;
        } else {
            after.prev = created;
        }
        size++;
        return created;
    }

    /**
     * Returns the node preceding position {@code index} (the sentinel for index 0).
     * Requires {@code 0 <= index <= size}.
     */
    private Node<E> nodeBefore(int index) {
        return index == size ? tail : node(index).prev;
    }

    /**
     * Finds the first occurrence of an element, scanning front to back.
     *
     * @param o element to look for (may be {@code null})
     * @return its index, or -1 if absent
     */
    @Override
    public int indexOf(Object o) {
        Node<E> cur = dummyHead.next;
        for (int index = 0; cur != null; index++, cur = cur.next) {
            if (Objects.equals(cur.data, o)) {
                return index;
            }
        }
        return -1;
    }

    /**
     * Finds the last occurrence of an element, scanning back to front.
     *
     * @param o element to look for (may be {@code null})
     * @return its index, or -1 if absent
     */
    @Override
    public int lastIndexOf(Object o) {
        Node<E> cur = tail;
        // Stop at the sentinel: its null payload is not an element.
        for (int index = size - 1; cur != dummyHead; index--, cur = cur.prev) {
            if (Objects.equals(cur.data, o)) {
                return index;
            }
        }
        return -1;
    }

    /**
     * Creates a list iterator positioned before the first element.
     *
     * @return the iterator
     */
    @Override
    public ListIterator<E> listIterator() {
        return listIterator(0);
    }

    /**
     * Creates a list iterator positioned before the element at {@code index}.
     *
     * @param index starting position, {@code 0 <= index <= size()}
     * @return the iterator
     * @throws IndexOutOfBoundsException if the index is out of range
     */
    @Override
    public ListIterator<E> listIterator(int index) {
        rangeCheckForAdd(index);
        return new MyLinkedListListIterator(index);
    }

    /**
     * Returns a copy of the elements in {@code [fromIndex, toIndex)}.
     *
     * <p>NOTE: unlike the {@link List} contract, the result is a snapshot, not a view;
     * changes to it are not reflected in this list and vice versa.
     *
     * @param fromIndex start index, inclusive
     * @param toIndex   end index, exclusive
     * @return a new list holding the selected elements
     * @throws IndexOutOfBoundsException if an endpoint is out of range
     * @throws IllegalArgumentException  if {@code fromIndex > toIndex}
     */
    @Override
    public List<E> subList(int fromIndex, int toIndex) {
        subListRangeCheck(fromIndex, toIndex, size);
        MyLinkedList<E> copy = new MyLinkedList<>();
        if (fromIndex < toIndex) {
            Node<E> cur = node(fromIndex);
            for (int i = fromIndex; i < toIndex; i++) {
                copy.add(cur.data);
                cur = cur.next;
            }
        }
        return copy;
    }

    /**
     * Appends an element to the end of the list.
     *
     * @param e element to append
     * @return always {@code true}
     */
    @Override
    public boolean add(E e) {
        add(size, e);
        return true;
    }

    /**
     * Removes the first occurrence of an element.
     *
     * @param o element to remove (may be {@code null})
     * @return {@code true} if an element was removed
     */
    @Override
    public boolean remove(Object o) {
        for (Node<E> cur = dummyHead.next; cur != null; cur = cur.next) {
            if (Objects.equals(cur.data, o)) {
                fastRemove(cur);
                return true;
            }
        }
        return false;
    }

    /**
     * Tests whether every element of the given collection is present in this list.
     *
     * @param c collection to check
     * @return {@code true} if all elements are contained
     */
    @Override
    public boolean containsAll(Collection<?> c) {
        for (Object e : c) {
            if (!contains(e)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Appends all elements of the given collection.
     *
     * @param c source collection
     * @return {@code true} if this list changed
     */
    @Override
    public boolean addAll(Collection<? extends E> c) {
        return addAll(size, c);
    }

    /**
     * Inserts all elements of the given collection at the given index.
     *
     * @param index insertion position, {@code 0 <= index <= size()}
     * @param c     source collection
     * @return {@code true} if this list changed, i.e. {@code c} was not empty
     * @throws IndexOutOfBoundsException if the index is out of range
     */
    @Override
    public boolean addAll(int index, Collection<? extends E> c) {
        rangeCheckForAdd(index);
        // Snapshot first so that list.addAll(list) does not chase its own tail.
        Object[] incoming = c.toArray();
        if (incoming.length == 0) {
            return false;
        }
        Node<E> prev = nodeBefore(index);
        for (Object object : incoming) {
            @SuppressWarnings("unchecked") // c is a Collection<? extends E>
            E element = (E) object;
            prev = linkAfter(prev, element);
        }
        return true;
    }

    /**
     * Removes every element that is contained in the given collection.
     *
     * @param c elements to remove
     * @return {@code true} if this list changed
     */
    @Override
    public boolean removeAll(Collection<?> c) {
        Objects.requireNonNull(c);
        return removeIf(c::contains);
    }

    /**
     * Keeps only the elements contained in the given collection.
     *
     * @param c elements to keep
     * @return {@code true} if this list changed
     */
    @Override
    public boolean retainAll(Collection<?> c) {
        Objects.requireNonNull(c);
        return removeIf(o -> !c.contains(o));
    }

    /**
     * Removes all elements, unlinking every node to help the garbage collector.
     */
    @Override
    public void clear() {
        Node<E> cur = dummyHead.next;
        while (cur != null) {
            Node<E> next = cur.next;
            cur.data = null;
            cur.prev = null;
            cur.next = null;
            cur = next;
        }
        dummyHead.next = null;
        tail = dummyHead;
        size = 0;
    }

    /**
     * Replaces the element at the given index.
     *
     * @param index position to overwrite
     * @param e     new element
     * @return the element previously stored there
     * @throws IndexOutOfBoundsException if the index is out of range
     */
    @Override
    public E set(int index, E e) {
        Objects.checkIndex(index, size);
        Node<E> node = node(index);
        E oldValue = node.data;
        node.data = e;
        return oldValue;
    }

    /**
     * Returns the element at the given index.
     *
     * @param index position to read
     * @return the element
     * @throws IndexOutOfBoundsException if the index is out of range
     */
    @Override
    public E get(int index) {
        Objects.checkIndex(index, size);
        return node(index).data;
    }

    /**
     * Returns the number of stored elements.
     *
     * @return element count
     */
    @Override
    public int size() {
        return size;
    }

    /**
     * Tests whether the list has no elements.
     *
     * @return {@code true} if empty
     */
    @Override
    public boolean isEmpty() {
        return size == 0;
    }

    /**
     * Tests whether the list contains the given element.
     *
     * @param o element to look for
     * @return {@code true} if present
     */
    @Override
    public boolean contains(Object o) {
        return indexOf(o) >= 0;
    }

    /**
     * Creates an iterator over the elements in index order.
     *
     * @return the iterator
     */
    @Override
    public Iterator<E> iterator() {
        return listIterator();
    }

    /**
     * Copies the elements into a new {@code Object[]} of length {@code size()}.
     *
     * @return the array
     */
    @Override
    public Object[] toArray() {
        Object[] res = new Object[size];
        int index = 0;
        for (Node<E> cur = dummyHead.next; cur != null; cur = cur.next) {
            res[index++] = cur.data;
        }
        return res;
    }

    /**
     * Copies the elements into the given array if it is large enough, otherwise into a new
     * array of the same runtime type. If there is room to spare, the slot right after the
     * last element is set to {@code null}.
     *
     * @param a   destination array
     * @param <T> array component type
     * @return the array holding the elements
     */
    @Override
    public <T> T[] toArray(T[] a) {
        if (a.length < size) {
            // Allocates an array of a's runtime type without needing reflection imports.
            a = Arrays.copyOf(a, size);
        }
        Object[] res = a;
        int index = 0;
        for (Node<E> cur = dummyHead.next; cur != null; cur = cur.next) {
            res[index++] = cur.data;
        }
        if (a.length > size) {
            a[size] = null;
        }
        return a;
    }

    /**
     * Compares with another list element by element, as required by {@link List#equals}.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof List)) {
            return false;
        }
        Iterator<?> other = ((List<?>) o).iterator();
        for (Node<E> cur = dummyHead.next; cur != null; cur = cur.next) {
            if (!other.hasNext() || !Objects.equals(cur.data, other.next())) {
                return false;
            }
        }
        return !other.hasNext();
    }

    /**
     * Computes the hash code defined by {@link List#hashCode}.
     */
    @Override
    public int hashCode() {
        int hash = 1;
        for (Node<E> cur = dummyHead.next; cur != null; cur = cur.next) {
            hash = 31 * hash + Objects.hashCode(cur.data);
        }
        return hash;
    }

    @Override
    public String toString() {
        return Arrays.toString(toArray());
    }

    /**
     * Returns the node at {@code index}, walking from whichever end is closer.
     * Requires {@code 0 <= index < size}.
     */
    private Node<E> node(int index) {
        if (index < (size >> 1)) {
            Node<E> cur = dummyHead.next;
            for (int i = 0; i < index; i++) {
                cur = cur.next;
            }
            return cur;
        } else {
            Node<E> cur = tail;
            for (int i = size - 1; i > index; i--) {
                cur = cur.prev;
            }
            return cur;
        }
    }

    private String outOfBoundsMsg(int index) {
        return "Index: " + index + ", Size: " + size;
    }

    private void rangeCheckForAdd(int index) {
        if (index > size || index < 0) {
            throw new IndexOutOfBoundsException(outOfBoundsMsg(index));
        }
    }

    private void subListRangeCheck(int fromIndex, int toIndex, int size) {
        if (fromIndex < 0) {
            throw new IndexOutOfBoundsException("fromIndex = " + fromIndex);
        }
        if (toIndex > size) {
            throw new IndexOutOfBoundsException("toIndex = " + toIndex);
        }
        if (fromIndex > toIndex) {
            throw new IllegalArgumentException("fromIndex(" + fromIndex +
                    ") > toIndex(" + toIndex + ")");
        }
    }

    /**
     * Doubly-linked list node.
     */
    private static class Node<E> {
        E data;
        Node<E> prev;
        Node<E> next;

        Node(E data, Node<E> prev, Node<E> next) {
            this.data = data;
            this.prev = prev;
            this.next = next;
        }
    }

    /**
     * Bidirectional iterator. Not fail-fast: concurrent structural modification made
     * outside this iterator is not detected.
     */
    private class MyLinkedListListIterator implements ListIterator<E> {
        /** Index of the element the next call to next() will return. */
        private int nextIndex;
        /** Node the next call to next() will return; null when positioned at the end. */
        private Node<E> next;
        /** Node most recently returned by next()/previous(); null if none or consumed. */
        private Node<E> lastReturned;

        MyLinkedListListIterator(int index) {
            nextIndex = index;
            next = (index == size) ? null : node(index);
        }

        @Override
        public boolean hasNext() {
            return nextIndex < size;
        }

        @Override
        public E next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            lastReturned = next;
            next = next.next;
            nextIndex++;
            return lastReturned.data;
        }

        @Override
        public void remove() {
            if (lastReturned == null) {
                throw new IllegalStateException();
            }
            // Read the successor before fastRemove() clears the node's links.
            Node<E> after = lastReturned.next;
            if (next == lastReturned) {
                // Last move was previous(): the removed node was the upcoming one.
                next = after;
            } else {
                // Last move was next(): everything ahead shifts one index down.
                nextIndex--;
            }
            fastRemove(lastReturned);
            lastReturned = null;
        }

        @Override
        public boolean hasPrevious() {
            return nextIndex > 0;
        }

        @Override
        public E previous() {
            if (!hasPrevious()) {
                throw new NoSuchElementException();
            }
            next = (next == null) ? tail : next.prev;
            lastReturned = next;
            nextIndex--;
            return lastReturned.data;
        }

        @Override
        public int nextIndex() {
            return nextIndex;
        }

        @Override
        public int previousIndex() {
            return nextIndex - 1;
        }

        @Override
        public void set(E e) {
            if (lastReturned == null) {
                throw new IllegalStateException();
            }
            lastReturned.data = e;
        }

        @Override
        public void add(E e) {
            Node<E> before = (next == null) ? tail : next.prev;
            linkAfter(before, e);
            // The new element sits behind the cursor, so next() is unaffected.
            nextIndex++;
            lastReturned = null;
        }
    }
}
总结
本实现参考了JDK的ArrayList和LinkedList的实现,主要实现了增删改查、扩容等功能,待完善的地方有:
- subList(),现在的实现和List接口要求不符
- 迭代器不支持快速失败
查看源码可以让我们使用工具更加得心应手。